This commit is contained in:
duanchao.bill
2022-04-27 11:51:12 +08:00
parent d1a8835b1e
commit 3a3d48228b
17 changed files with 344 additions and 212 deletions

View File

@@ -7,8 +7,16 @@ from httprunner.runner import HttpRunner
from httprunner.step import Step
from httprunner.step_request import RunRequest
from httprunner.step_testcase import RunTestCase
from httprunner.step_sql_request import RunSqlRequest, StepSqlRequestValidation, StepSqlRequestExtraction
from httprunner.step_thrift_request import RunThriftRequest, StepThriftRequestValidation, StepThriftRequestExtraction
from httprunner.step_sql_request import (
RunSqlRequest,
StepSqlRequestValidation,
StepSqlRequestExtraction,
)
from httprunner.step_thrift_request import (
RunThriftRequest,
StepThriftRequestValidation,
StepThriftRequestExtraction,
)
__all__ = [
"__version__",

View File

@@ -43,8 +43,7 @@ class TestCompat(unittest.TestCase):
"body.data.buildings[0].building_id",
)
self.assertEqual(
compat._convert_jmespath("body.users[-1]"),
"body.users[-1]",
compat._convert_jmespath("body.users[-1]"), "body.users[-1]",
)
self.assertEqual(
compat._convert_jmespath("body.result.WorkNode_-1"),

View File

@@ -27,7 +27,7 @@ class DBEngine(object):
"""
for k, v in row.items():
if isinstance(v, datetime.datetime):
row[k] = v.strftime('%Y-%m-%d %H:%M:%S')
row[k] = v.strftime("%Y-%m-%d %H:%M:%S")
elif isinstance(v, datetime.date):
row[k] = v.strftime("%Y-%m-%d")
elif isinstance(v, str):
@@ -73,7 +73,6 @@ class DBEngine(object):
def update(self, query, commit=True):
return self._fetch(query=query, commit=commit)
if __name__ == '__main__':
db = DBEngine(
f"mysql+pymysql://xxxxx:xxxxx@10.0.0.1:3306/dbname?charset=utf8mb4")
if __name__ == "__main__":
db = DBEngine(f"mysql+pymysql://xxxxx:xxxxx@10.0.0.1:3306/dbname?charset=utf8mb4")

View File

@@ -97,11 +97,7 @@ class TestLoader(unittest.TestCase):
)
def test_load_env_path_not_exist(self):
dot_env_path = os.path.join(
os.getcwd(),
"tests",
"data",
)
dot_env_path = os.path.join(os.getcwd(), "tests", "data",)
env_variables_mapping = loader.load_dot_env_file(dot_env_path)
self.assertEqual(env_variables_mapping, {})

View File

@@ -534,8 +534,7 @@ def main_make(tests_paths: List[Text]) -> List[Text]:
def init_make_parser(subparsers):
"""make testcases: parse command line options and run commands."""
parser = subparsers.add_parser(
"make",
help="Convert YAML/JSON testcases to pytest cases.",
"make", help="Convert YAML/JSON testcases to pytest cases.",
)
parser.add_argument(
"testcase_path", nargs="*", help="Specify YAML/JSON testcase file/folder path"

View File

@@ -73,8 +73,7 @@ from request_methods.request_with_functions_test import (
content,
)
self.assertIn(
".call(RequestWithFunctions)",
content,
".call(RequestWithFunctions)", content,
)
def test_make_testcase_folder(self):
@@ -112,8 +111,7 @@ from request_methods.request_with_functions_test import (
)
loader.project_meta = None
self.assertEqual(
ensure_file_abs_path_valid(os.getcwd()),
os.getcwd(),
ensure_file_abs_path_valid(os.getcwd()), os.getcwd(),
)
loader.project_meta = None
self.assertEqual(
@@ -124,17 +122,11 @@ from request_methods.request_with_functions_test import (
def test_convert_testcase_path(self):
self.assertEqual(
convert_testcase_path(os.path.join(self.data_dir, "a-b.c", "2 3.yml")),
(
os.path.join(self.data_dir, "a_b_c", "T2_3_test.py"),
"T23",
),
(os.path.join(self.data_dir, "a_b_c", "T2_3_test.py"), "T23",),
)
self.assertEqual(
convert_testcase_path(os.path.join(self.data_dir, "a-b.c", "中文case.yml")),
(
os.path.join(self.data_dir, "a_b_c", "中文case_test.py"),
"中文Case",
),
(os.path.join(self.data_dir, "a_b_c", "中文case_test.py"), "中文Case",),
)
def test_make_config_chain_style(self):
@@ -153,11 +145,7 @@ from request_methods.request_with_functions_test import (
def test_make_teststep_chain_style(self):
step = {
"name": "get with params",
"variables": {
"foo1": "bar1",
"foo2": 123,
"sum_v": "${sum_two(1, 2)}",
},
"variables": {"foo1": "bar1", "foo2": 123, "sum_v": "${sum_two(1, 2)}",},
"request": {
"method": "GET",
"url": "/get",

View File

@@ -77,10 +77,11 @@ class TransportEnum(Text, Enum):
class TThriftRequest(BaseModel):
""" rpc request model"""
method: Text = ''
method: Text = ""
params: Dict = {}
thrift_client: Any = None
idl_path: Text = '' # idl local path
idl_path: Text = "" # idl local path
timeout: int = 10 # sec
transport: TransportEnum = TransportEnum.BUFFERED
include_dirs: List[Union[Text, None]] = [] # param of thriftpy2.load
@@ -106,6 +107,7 @@ class SqlMethodEnum(Text, Enum):
class TSqlRequest(BaseModel):
""" sql request model"""
db_config: TConfigDB = TConfigDB()
method: SqlMethodEnum = None
sql: Text = None

View File

@@ -476,9 +476,7 @@ def parse_variables_mapping(
return parsed_variables
def parse_parameters(
parameters: Dict,
) -> List[Dict]:
def parse_parameters(parameters: Dict,) -> List[Dict]:
"""parse parameters and generate cartesian product.
Args:

View File

@@ -124,20 +124,17 @@ class ResponseObjectBase(object):
self.parser = parser
self.validation_results: Dict = {}
def extract(self,
extractors: Dict[Text, Text],
variables_mapping: VariablesMapping = None,
) -> Dict[Text, Any]:
def extract(
self, extractors: Dict[Text, Text], variables_mapping: VariablesMapping = None,
) -> Dict[Text, Any]:
if not extractors:
return {}
extract_mapping = {}
for key, field in extractors.items():
if '$' in field:
if "$" in field:
# field contains variable or function
field = self.parser.parse_data(
field, variables_mapping
)
field = self.parser.parse_data(field, variables_mapping)
field_value = self._search_jmespath(field)
extract_mapping[key] = field_value
@@ -148,9 +145,7 @@ class ResponseObjectBase(object):
raise NotImplementedError("_search_jmespath not override")
def validate(
self,
validators: Validators,
variables_mapping: VariablesMapping = None,
self, validators: Validators, variables_mapping: VariablesMapping = None,
):
variables_mapping = variables_mapping or {}
@@ -173,9 +168,7 @@ class ResponseObjectBase(object):
check_item = u_validator["check"]
if "$" in check_item:
# check_item is variable or function
check_item = self.parser.parse_data(
check_item, variables_mapping
)
check_item = self.parser.parse_data(check_item, variables_mapping)
check_item = parse_string_value(check_item)
if check_item and isinstance(check_item, Text):

View File

@@ -61,9 +61,6 @@ class TestResponse(unittest.TestCase):
def test_validate_functions(self):
variables_mapping = {"index": 1}
self.resp_obj.validate(
[
{"eq": ["${get_num(0)}", 0]},
{"eq": ["${get_num($index)}", 1]},
],
[{"eq": ["${get_num(0)}", 0]}, {"eq": ["${get_num($index)}", 1]},],
variables_mapping=variables_mapping,
)

View File

@@ -92,7 +92,7 @@ class SessionRunner(object):
def with_thrift_client(self, thrift_client) -> "SessionRunner":
self.thrift_client = thrift_client
def with_db_engine(self,db_engine):
def with_db_engine(self, db_engine):
self.db_engine = db_engine
def __parse_config(self, param: Dict = None) -> None:

View File

@@ -8,7 +8,11 @@ from httprunner.step_request import (
StepRequestValidation,
)
from httprunner.step_testcase import StepRefCase
from httprunner.step_sql_request import RunSqlRequest, StepSqlRequestValidation, StepSqlRequestExtraction
from httprunner.step_sql_request import (
RunSqlRequest,
StepSqlRequestValidation,
StepSqlRequestExtraction,
)
class Step(object):

View File

@@ -67,10 +67,7 @@ def call_hooks(
def run_step_request(runner: HttpRunner, step: TStep) -> StepResult:
"""run teststep: request"""
step_result = StepResult(
name=step.name,
success=False,
)
step_result = StepResult(name=step.name, success=False,)
start_time = time.time()
step.variables = runner.merge_step_variables(step.variables)
@@ -82,8 +79,7 @@ def run_step_request(runner: HttpRunner, step: TStep) -> StepResult:
request_dict.pop("upload", None)
parsed_request_dict = runner.parser.parse_data(request_dict, step.variables)
parsed_request_dict["headers"].setdefault(
"HRUN-Request-ID",
f"HRUN-{runner.case_id}-{str(int(time.time() * 1000))[-6:]}",
"HRUN-Request-ID", f"HRUN-{runner.case_id}-{str(int(time.time() * 1000))[-6:]}",
)
step.variables["request"] = parsed_request_dict

View File

@@ -10,7 +10,11 @@ from httprunner.models import IStep, StepResult, TStep
from httprunner.models import TSqlRequest, SqlMethodEnum
from httprunner.response import SqlResponseObject
from httprunner.runner import HttpRunner
from httprunner.step_request import call_hooks, StepRequestExtraction, StepRequestValidation
from httprunner.step_request import (
call_hooks,
StepRequestExtraction,
StepRequestValidation,
)
from httprunner.database.engine import DBEngine
from httprunner.exceptions import SqlMethodNotSupport
@@ -19,32 +23,42 @@ def run_step_sql_request(runner: HttpRunner, step: TStep) -> StepResult:
"""run teststep:sql request"""
start_time = time.time()
step_result = StepResult(
name=step.name,
success=False,
)
step_result = StepResult(name=step.name, success=False,)
step.variables = runner.merge_step_variables(step.variables)
# parse
request_dict = step.sql_request.dict()
parsed_request_dict = runner.parser.parse_data(
request_dict, step.variables
)
parsed_request_dict = runner.parser.parse_data(request_dict, step.variables)
config = runner.get_config()
parsed_request_dict["db_config"]["psm"] = parsed_request_dict["db_config"]["psm"] or config.db.psm
parsed_request_dict["db_config"]["user"] = parsed_request_dict["db_config"]["user"] or config.db.user
parsed_request_dict["db_config"]["password"] = parsed_request_dict["db_config"]["password"] or config.db.password
parsed_request_dict["db_config"]["ip"] = parsed_request_dict["db_config"]["ip"] or config.db.ip
parsed_request_dict["db_config"]["port"] = parsed_request_dict["db_config"]["port"] or config.db.port
parsed_request_dict["db_config"]["database"] = parsed_request_dict["db_config"]["database"] or config.db.database
parsed_request_dict["db_config"]["psm"] = (
parsed_request_dict["db_config"]["psm"] or config.db.psm
)
parsed_request_dict["db_config"]["user"] = (
parsed_request_dict["db_config"]["user"] or config.db.user
)
parsed_request_dict["db_config"]["password"] = (
parsed_request_dict["db_config"]["password"] or config.db.password
)
parsed_request_dict["db_config"]["ip"] = (
parsed_request_dict["db_config"]["ip"] or config.db.ip
)
parsed_request_dict["db_config"]["port"] = (
parsed_request_dict["db_config"]["port"] or config.db.port
)
parsed_request_dict["db_config"]["database"] = (
parsed_request_dict["db_config"]["database"] or config.db.database
)
if parsed_request_dict["db_config"]["psm"]:
runner.db_engine = DBEngine(f'mysql+pymysql://:@/?charset=utf8mb4&db_psm={parsed_request_dict["psm"]}')
runner.db_engine = DBEngine(
f'mysql+pymysql://:@/?charset=utf8mb4&db_psm={parsed_request_dict["psm"]}'
)
else:
runner.db_engine = DBEngine(
f'mysql+pymysql://{parsed_request_dict["db_config"]["user"]}:'
f'{parsed_request_dict["db_config"]["password"]}@{parsed_request_dict["db_config"]["ip"]}:'
f'{parsed_request_dict["db_config"]["port"]}/{parsed_request_dict["db_config"]["database"]}'
f'?charset=utf8mb4')
f"?charset=utf8mb4"
)
# parsed_request_dict["headers"].setdefault(
# "HRUN-Request-ID",
@@ -57,19 +71,23 @@ def run_step_sql_request(runner: HttpRunner, step: TStep) -> StepResult:
logger.info(f"Executing SQL: {parsed_request_dict['sql']}")
if step.sql_request.method == SqlMethodEnum.FETCHONE:
sql_resp = runner.db_engine.fetchone(parsed_request_dict['sql'])
sql_resp = runner.db_engine.fetchone(parsed_request_dict["sql"])
elif step.sql_request.method == SqlMethodEnum.INSERT:
sql_resp = runner.db_engine.insert(parsed_request_dict['sql'])
sql_resp = runner.db_engine.insert(parsed_request_dict["sql"])
elif step.sql_request.method == SqlMethodEnum.FETCHMANY:
sql_resp = runner.db_engine.fetchmany(parsed_request_dict['sql'], parsed_request_dict['size'])
sql_resp = runner.db_engine.fetchmany(
parsed_request_dict["sql"], parsed_request_dict["size"]
)
elif step.sql_request.method == SqlMethodEnum.FETCHALL:
sql_resp = runner.db_engine.fetchall(parsed_request_dict['sql'])
sql_resp = runner.db_engine.fetchall(parsed_request_dict["sql"])
elif step.sql_request.method == SqlMethodEnum.UPDATE:
sql_resp = runner.db_engine.update(parsed_request_dict['sql'])
sql_resp = runner.db_engine.update(parsed_request_dict["sql"])
elif step.sql_request.method == SqlMethodEnum.DELETE:
sql_resp = runner.db_engine.delete(parsed_request_dict['sql'])
sql_resp = runner.db_engine.delete(parsed_request_dict["sql"])
else:
raise SqlMethodNotSupport(f"step.sql_request.method {parsed_request_dict['method']} not support")
raise SqlMethodNotSupport(
f"step.sql_request.method {parsed_request_dict['method']} not support"
)
resp_obj = SqlResponseObject(sql_resp, parser=runner.parser)
step.variables["sql_response"] = resp_obj
@@ -107,9 +125,7 @@ def run_step_sql_request(runner: HttpRunner, step: TStep) -> StepResult:
# validate
validators = step.validators
try:
resp_obj.validate(
validators, variables_mapping
)
resp_obj.validate(validators, variables_mapping)
step_result.success = True
except ValidationFailure:
log_sql_req_resp_details()
@@ -128,7 +144,7 @@ class StepSqlRequestValidation(StepRequestValidation):
def __init__(self, step: TStep):
self.__step = step
super().__init__(step)
def run(self, runner: HttpRunner):
return run_step_sql_request(runner, self.__step)
@@ -154,7 +170,9 @@ class RunSqlRequest(IStep):
self.__step.variables.update(variables)
return self
def with_db_config(self, psm=None, user=None, password=None, ip=None, port=None, database=None):
def with_db_config(
self, psm=None, user=None, password=None, ip=None, port=None, database=None
):
if psm:
self.__step.sql_request.db_config.psm = psm
if user:
@@ -205,7 +223,9 @@ class RunSqlRequest(IStep):
self.__step.retry_interval = retry_interval
return self
def teardown_hook(self, hook: Text, assign_var_name: Text = None) -> "RunSqlRequest":
def teardown_hook(
self, hook: Text, assign_var_name: Text = None
) -> "RunSqlRequest":
if assign_var_name:
self.__step.teardown_hooks.append({assign_var_name: hook})
else:
@@ -239,6 +259,8 @@ class RunSqlRequest(IStep):
def validate(self) -> StepSqlRequestValidation:
return StepSqlRequestValidation(self.__step)
def with_jmespath(self, jmes_path: Text, var_name: Text) -> "StepSqlRequestExtraction":
def with_jmespath(
self, jmes_path: Text, var_name: Text
) -> "StepSqlRequestExtraction":
self.__step.extract[var_name] = jmes_path
return StepSqlRequestExtraction(self.__step)

View File

@@ -7,7 +7,11 @@ from httprunner import utils
from httprunner.exceptions import ValidationFailure
from httprunner.models import IStep, StepResult, TStep, ProtoType, TransType
from httprunner.runner import HttpRunner
from httprunner.step_request import call_hooks, StepRequestExtraction, StepRequestValidation
from httprunner.step_request import (
call_hooks,
StepRequestExtraction,
StepRequestValidation,
)
from httprunner.models import TThriftRequest
from httprunner.response import ThriftResponseObject
@@ -18,29 +22,40 @@ def run_step_thrift_request(runner: HttpRunner, step: TStep) -> StepResult:
"""run teststep:thrift request"""
start_time = time.time()
step_result = StepResult(
name=step.name,
success=False,
)
step_result = StepResult(name=step.name, success=False,)
step.variables = runner.merge_step_variables(step.variables)
# parse
request_dict = step.thrift_request.dict()
parsed_request_dict = runner.parser.parse_data(
request_dict, step.variables
)
parsed_request_dict = runner.parser.parse_data(request_dict, step.variables)
config = runner.get_config()
parsed_request_dict["psm"] = parsed_request_dict["psm"] or config.thrift.psm
parsed_request_dict["env"] = parsed_request_dict["env"] or config.thrift.env
parsed_request_dict["cluster"] = parsed_request_dict["cluster"] or config.thrift.cluster
parsed_request_dict["idl_path"] = parsed_request_dict["idl_path"] or config.thrift.idl_path
parsed_request_dict["include_dirs"] = parsed_request_dict["include_dirs"] or config.thrift.include_dirs
parsed_request_dict["method"] = parsed_request_dict["method"] or config.thrift.method
parsed_request_dict["service_name"] = parsed_request_dict["service_name"] or config.thrift.service_name
parsed_request_dict["cluster"] = (
parsed_request_dict["cluster"] or config.thrift.cluster
)
parsed_request_dict["idl_path"] = (
parsed_request_dict["idl_path"] or config.thrift.idl_path
)
parsed_request_dict["include_dirs"] = (
parsed_request_dict["include_dirs"] or config.thrift.include_dirs
)
parsed_request_dict["method"] = (
parsed_request_dict["method"] or config.thrift.method
)
parsed_request_dict["service_name"] = (
parsed_request_dict["service_name"] or config.thrift.service_name
)
parsed_request_dict["ip"] = parsed_request_dict["ip"] or config.thrift.ip
parsed_request_dict["port"] = parsed_request_dict["port"] or config.thrift.port
parsed_request_dict["proto_type"] = parsed_request_dict["proto_type"] or config.thrift.proto_type
parsed_request_dict["trans_port"] = parsed_request_dict["trans_type"] or config.thrift.trans_type
parsed_request_dict["timeout"] = parsed_request_dict["timeout"] or config.thrift.timeout
parsed_request_dict["proto_type"] = (
parsed_request_dict["proto_type"] or config.thrift.proto_type
)
parsed_request_dict["trans_port"] = (
parsed_request_dict["trans_type"] or config.thrift.trans_type
)
parsed_request_dict["timeout"] = (
parsed_request_dict["timeout"] or config.thrift.timeout
)
parsed_request_dict["thrift_client"] = parsed_request_dict["thrift_client"]
# parsed_request_dict["headers"].setdefault(
@@ -53,17 +68,24 @@ def run_step_thrift_request(runner: HttpRunner, step: TStep) -> StepResult:
runner.thrift_client = parsed_request_dict["thrift_client"]
if not runner.thrift_client:
runner.thrift_client = ThriftClient(parsed_request_dict["idl_path"], parsed_request_dict["service_name"],
parsed_request_dict["ip"], parsed_request_dict["port"],
parsed_request_dict["timeout"], parsed_request_dict["proto_type"],
parsed_request_dict["trans_port"])
runner.thrift_client = ThriftClient(
parsed_request_dict["idl_path"],
parsed_request_dict["service_name"],
parsed_request_dict["ip"],
parsed_request_dict["port"],
parsed_request_dict["timeout"],
parsed_request_dict["proto_type"],
parsed_request_dict["trans_port"],
)
# setup hooks
if step.setup_hooks:
call_hooks(runner, step.setup_hooks, step.variables, "setup request")
# thrift request
resp = runner.thrift_client.send_request(parsed_request_dict["params"], parsed_request_dict["method"])
resp = runner.thrift_client.send_request(
parsed_request_dict["params"], parsed_request_dict["method"]
)
resp_obj = ThriftResponseObject(resp, parser=runner.parser)
step.variables["thrift_response"] = resp_obj
@@ -72,7 +94,9 @@ def run_step_thrift_request(runner: HttpRunner, step: TStep) -> StepResult:
call_hooks(runner, step.teardown_hooks, step.variables, "teardown request")
def log_thrift_req_resp_details():
err_msg = "\n{} THRIFT DETAILED REQUEST & RESPONSE {}\n".format("*" * 32, "*" * 32)
err_msg = "\n{} THRIFT DETAILED REQUEST & RESPONSE {}\n".format(
"*" * 32, "*" * 32
)
# log request
err_msg += "====== thrift request details ======\n"
@@ -101,9 +125,7 @@ def run_step_thrift_request(runner: HttpRunner, step: TStep) -> StepResult:
# validate
validators = step.validators
try:
resp_obj.validate(
validators, variables_mapping
)
resp_obj.validate(validators, variables_mapping)
step_result.success = True
except ValidationFailure:
log_thrift_req_resp_details()
@@ -153,7 +175,9 @@ class RunThriftRequest(IStep):
self.__step.retry_interval = retry_interval
return self
def teardown_hook(self, hook: Text, assign_var_name: Text = None) -> "RunThriftRequest":
def teardown_hook(
self, hook: Text, assign_var_name: Text = None
) -> "RunThriftRequest":
if assign_var_name:
self.__step.teardown_hooks.append({assign_var_name: hook})
else:
@@ -182,11 +206,13 @@ class RunThriftRequest(IStep):
self.__step.thrift_request.include_dirs = [idl_root_path]
return self
def with_thrift_client(self, thrift_client: Union["ThriftClient", str]) -> "RunThriftRequest":
def with_thrift_client(
self, thrift_client: Union["ThriftClient", str]
) -> "RunThriftRequest":
self.__step.thrift_request.thrift_client = thrift_client
return self
def with_ip(self,ip: str) -> "RunThriftRequest":
def with_ip(self, ip: str) -> "RunThriftRequest":
self.__step.thrift_request.ip = ip
return self
@@ -194,11 +220,11 @@ class RunThriftRequest(IStep):
self.__step.thrift_request.port = port
return self
def with_proto_type(self,proto_type:ProtoType) -> "RunThriftRequest":
def with_proto_type(self, proto_type: ProtoType) -> "RunThriftRequest":
self.__step.thrift_request.proto_type = proto_type
return self
def with_trans_type(self,trans_type:TransType) -> "RunThriftRequest":
def with_trans_type(self, trans_type: TransType) -> "RunThriftRequest":
self.__step.thrift_request.proto_type = trans_type
return self
@@ -220,6 +246,8 @@ class RunThriftRequest(IStep):
def validate(self) -> StepThriftRequestValidation:
return StepThriftRequestValidation(self.__step)
def with_jmespath(self, jmes_path: Text, var_name: Text) -> "StepThriftRequestExtraction":
def with_jmespath(
self, jmes_path: Text, var_name: Text
) -> "StepThriftRequestExtraction":
self.__step.extract[var_name] = jmes_path
return StepThriftRequestExtraction(self.__step)

View File

@@ -19,21 +19,21 @@ text_characters = "".join(map(chr, range(32, 127))) + "\n\r\t\b"
_null_trans = str.maketrans("", "")
ESCAPE = re.compile(r'[\x00-\x1f\\"\b\f\n\r\t]')
ESCAPE_ASCII = re.compile(r'([\\"]|[^\ -~])')
HAS_UTF8 = re.compile(r'[\x80-\xff]')
HAS_UTF8 = re.compile(r"[\x80-\xff]")
ESCAPE_DCT = {
'\\': '\\\\',
"\\": "\\\\",
'"': '\\"',
'\b': '\\b',
'\f': '\\f',
'\n': '\\n',
'\r': '\\r',
'\t': '\\t',
"\b": "\\b",
"\f": "\\f",
"\n": "\\n",
"\r": "\\r",
"\t": "\\t",
}
for i in range(0x20):
ESCAPE_DCT.setdefault(chr(i), '\\u{0:04x}'.format(i))
ESCAPE_DCT.setdefault(chr(i), "\\u{0:04x}".format(i))
# ESCAPE_DCT.setdefault(chr(i), '\\u%04x' % (i,))
INFINITY = float('inf')
INFINITY = float("inf")
FLOAT_REPR = repr
@@ -66,7 +66,7 @@ def unicode_2_utf8_keep_native(para):
elif type(para) is tuple:
return tuple(unicode_2_utf8_keep_native(list(para)))
elif type(para) is str:
return para.encode('utf-8')
return para.encode("utf-8")
else:
logging.debug("type========", type(para))
# if issubclass(type(para), dict):
@@ -93,7 +93,7 @@ def py_encode_basestring_ascii(s):
"""
if isinstance(s, str) and HAS_UTF8.search(s) is not None:
s = s.decode('utf-8')
s = s.decode("utf-8")
def replace(match):
s = match.group(0)
@@ -102,27 +102,25 @@ def py_encode_basestring_ascii(s):
except KeyError:
n = ord(s)
if n < 0x10000:
return '\\u{0:04x}'.format(n)
return "\\u{0:04x}".format(n)
# return '\\u%04x' % (n,)
else:
# surrogate pair
n -= 0x10000
s1 = 0xd800 | ((n >> 10) & 0x3ff)
s2 = 0xdc00 | (n & 0x3ff)
return '\\u{0:04x}\\u{1:04x}'.format(s1, s2)
s1 = 0xD800 | ((n >> 10) & 0x3FF)
s2 = 0xDC00 | (n & 0x3FF)
return "\\u{0:04x}\\u{1:04x}".format(s1, s2)
# return '\\u%04x\\u%04x' % (s1, s2)
return '"' + str(ESCAPE_ASCII.sub(replace, s)) + '"'
encode_basestring_ascii = (
c_encode_basestring_ascii or py_encode_basestring_ascii)
encode_basestring_ascii = c_encode_basestring_ascii or py_encode_basestring_ascii
class ThriftJSONDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
self._thrift_class = kwargs.pop('thrift_class')
self._thrift_class = kwargs.pop("thrift_class")
super(ThriftJSONDecoder, self).__init__(*args, **kwargs)
def decode(self, json_str):
@@ -130,9 +128,12 @@ class ThriftJSONDecoder(json.JSONDecoder):
dct = json_str
else:
dct = super(ThriftJSONDecoder, self).decode(json_str)
return self._convert(dct, TType.STRUCT,
# (self._thrift_class, self._thrift_class.thrift_spec))
self._thrift_class)
return self._convert(
dct,
TType.STRUCT,
# (self._thrift_class, self._thrift_class.thrift_spec))
self._thrift_class,
)
def _convert(self, val, ttype, ttype_info):
if ttype == TType.STRUCT:
@@ -156,7 +157,9 @@ class ThriftJSONDecoder(json.JSONDecoder):
if val is None or field_name not in val:
continue
converted_val = self._convert(val[field_name], field_ttype, field_ttype_info)
converted_val = self._convert(
val[field_name], field_ttype, field_ttype_info
)
setattr(ret, field_name, converted_val)
elif ttype == TType.LIST:
if type(ttype_info) != tuple: # 说明是基础类型了, 无法在细分
@@ -174,7 +177,9 @@ class ThriftJSONDecoder(json.JSONDecoder):
else:
(element_ttype, element_ttype_info) = ttype_info
if val is not None:
ret = set([self._convert(x, element_ttype, element_ttype_info) for x in val])
ret = set(
[self._convert(x, element_ttype, element_ttype_info) for x in val]
)
else:
ret = None
@@ -193,8 +198,15 @@ class ThriftJSONDecoder(json.JSONDecoder):
val_ttype, val_ttype_info = ttype_info[1]
if val is not None:
ret = dict([(self._convert(k, key_ttype, key_ttype_info),
self._convert(v, val_ttype, val_ttype_info)) for (k, v) in val.items()])
ret = dict(
[
(
self._convert(k, key_ttype, key_ttype_info),
self._convert(v, val_ttype, val_ttype_info),
)
for (k, v) in val.items()
]
)
else:
ret = None
elif ttype == TType.STRING:
@@ -228,13 +240,15 @@ class ThriftJSONDecoder(json.JSONDecoder):
else:
ret = None
else:
raise TypeError('Unrecognized thrift field type: %s' % ttype)
raise TypeError("Unrecognized thrift field type: %s" % ttype)
return ret
def json2thrift(json_str, thrift_class):
logging.debug(json_str)
return json.loads(json_str, cls=ThriftJSONDecoder, thrift_class=thrift_class, strict=False)
return json.loads(
json_str, cls=ThriftJSONDecoder, thrift_class=thrift_class, strict=False
)
def dumper(obj):
@@ -245,14 +259,33 @@ def dumper(obj):
class MyJSONEncoder(json.JSONEncoder):
def __init__(self, skipkeys=False, ensure_ascii=True, check_circular=True,
allow_nan=True, indent=None, separators=None,
encoding='utf-8', default=None, sort_keys=False, **kw):
super(MyJSONEncoder, self).__init__(skipkeys=skipkeys, ensure_ascii=ensure_ascii,
check_circular=check_circular, allow_nan=allow_nan, indent=indent,
separators=separators, encoding=encoding, default=default,
sort_keys=sort_keys)
self.skip_nonutf8_value = kw.get('skip_nonutf8_value', False) # 默认不skip忽略非utf-8编码的字段
def __init__(
self,
skipkeys=False,
ensure_ascii=True,
check_circular=True,
allow_nan=True,
indent=None,
separators=None,
encoding="utf-8",
default=None,
sort_keys=False,
**kw
):
super(MyJSONEncoder, self).__init__(
skipkeys=skipkeys,
ensure_ascii=ensure_ascii,
check_circular=check_circular,
allow_nan=allow_nan,
indent=indent,
separators=separators,
encoding=encoding,
default=default,
sort_keys=sort_keys,
)
self.skip_nonutf8_value = kw.get(
"skip_nonutf8_value", False
) # 默认不skip忽略非utf-8编码的字段
def encode(self, o):
"""Return a JSON string representation of a Python data structure.
@@ -266,8 +299,7 @@ class MyJSONEncoder(json.JSONEncoder):
if isinstance(o, str):
_encoding = self.encoding
if (_encoding is not None
and not (_encoding == 'utf-8')):
if _encoding is not None and not (_encoding == "utf-8"):
o = o.decode(_encoding)
if self.ensure_ascii:
return encode_basestring_ascii(o)
@@ -288,10 +320,10 @@ class MyJSONEncoder(json.JSONEncoder):
tmp_chunks.append(unicode_2_utf8_keep_native(chunk))
except Exception as err:
logging.debug(traceback.format_exc())
return ''.join(tmp_chunks)
return "".join(tmp_chunks)
# 保留老的逻辑, /usr/lib/python2.7/package/json/__init__.py dumps接口
return ''.join(chunks)
return "".join(chunks)
class ThriftJSONEncoder(json.JSONEncoder):
@@ -299,13 +331,32 @@ class ThriftJSONEncoder(json.JSONEncoder):
add by braver(Braver@bytedance.com)
"""
def __init__(self, skipkeys=False, ensure_ascii=True, check_circular=True,
allow_nan=True, indent=None, separators=None, default=None, sort_keys=False, **kw):
def __init__(
self,
skipkeys=False,
ensure_ascii=True,
check_circular=True,
allow_nan=True,
indent=None,
separators=None,
default=None,
sort_keys=False,
**kw
):
super(ThriftJSONEncoder, self).__init__(skipkeys=skipkeys, ensure_ascii=ensure_ascii,
check_circular=check_circular, allow_nan=allow_nan, indent=indent,
separators=separators, default=default, sort_keys=sort_keys)
self.skip_nonutf8_value = kw.get('skip_nonutf8_value', False) # 默认不skip忽略非utf-8编码的字段
super(ThriftJSONEncoder, self).__init__(
skipkeys=skipkeys,
ensure_ascii=ensure_ascii,
check_circular=check_circular,
allow_nan=allow_nan,
indent=indent,
separators=separators,
default=default,
sort_keys=sort_keys,
)
self.skip_nonutf8_value = kw.get(
"skip_nonutf8_value", False
) # 默认不skip忽略非utf-8编码的字段
def encode(self, o):
"""Return a JSON string representation of a Python data structure.
@@ -318,8 +369,7 @@ class ThriftJSONEncoder(json.JSONEncoder):
if isinstance(o, str):
if isinstance(o, str):
_encoding = self.encoding
if (_encoding is not None
and not (_encoding == 'utf-8')):
if _encoding is not None and not (_encoding == "utf-8"):
o = o.decode(_encoding)
if self.ensure_ascii:
return encode_basestring_ascii(o)
@@ -340,18 +390,18 @@ class ThriftJSONEncoder(json.JSONEncoder):
tmp_chunks.append(unicode_2_utf8_keep_native(chunk))
except Exception as err:
logging.debug(traceback.format_exc())
return ''.join(tmp_chunks)
return "".join(tmp_chunks)
# 保留老的逻辑, /usr/lib/python2.7/package/json/__init__.py dumps接口
return ''.join(chunks)
return "".join(chunks)
def default(self, o):
if isinstance(o, bytes):
return str(o, encoding='utf-8')
if not hasattr(o, 'thrift_spec'):
return str(o, encoding="utf-8")
if not hasattr(o, "thrift_spec"):
return super(ThriftJSONEncoder, self).default(o)
spec = getattr(o, 'thrift_spec')
spec = getattr(o, "thrift_spec")
ret = {}
for tag, field in spec.items():
if field is None:
@@ -370,30 +420,42 @@ class ThriftJSONEncoder(json.JSONEncoder):
val = list(val) # 统一转成数组(list/set)
is_need_binary_bs64 = False
if type(field_ttype_info) != tuple: # 基础类型
if field_ttype_info in [TType.BYTE] and type(val[0]) in [str] and not istext(
val[0]):
if (
field_ttype_info in [TType.BYTE]
and type(val[0]) in [str]
and not istext(val[0])
):
is_need_binary_bs64 = True
if is_need_binary_bs64:
for index, item in enumerate(val):
if item and type(item) in [str] and not istext(item):
val[index] = base64.b64encode(item) # 判断为二进制字符串, 需要进行base64编码
if field_type in [TType.BYTE] and type(val) in [str]: # 说明是string(明文string或者binary)
val[index] = base64.b64encode(
item
) # 判断为二进制字符串, 需要进行base64编码
if field_type in [TType.BYTE] and type(val) in [
str
]: # 说明是string(明文string或者binary)
# 需要对二进制字节字符串字段进行base64编码, 将二进制字节串字段->ascii字符编码的base64编码明文串
if val and not istext(val): # 说明是该字段非空且为binary string
print('4' * 100, val)
val = base64.b64encode(val.encode('utf-8'))
print("4" * 100, val)
val = base64.b64encode(val.encode("utf-8"))
# val = base64.b64encode(val) # 进行base64编码处理, 不然该字段序列化为json时会报错
# if val != default:
ret[field_name] = val
if 'request_id' in o.__dict__:
ret['request_id'] = o.__dict__['request_id']
if 'rpc_latency' in o.__dict__:
ret['rpc_latency'] = o.__dict__['rpc_latency']
if "request_id" in o.__dict__:
ret["request_id"] = o.__dict__["request_id"]
if "rpc_latency" in o.__dict__:
ret["rpc_latency"] = o.__dict__["rpc_latency"]
return ret
def thrift2json(obj, skip_nonutf8_value=False):
return json.dumps(obj, cls=ThriftJSONEncoder, ensure_ascii=False, skip_nonutf8_value=skip_nonutf8_value)
return json.dumps(
obj,
cls=ThriftJSONEncoder,
ensure_ascii=False,
skip_nonutf8_value=skip_nonutf8_value,
)
def thrift2dict(obj):
@@ -403,8 +465,11 @@ def thrift2dict(obj):
dict2thrift = json2thrift
if __name__ == '__main__':
if __name__ == "__main__":
print(istext("Всего за {$price$}, а доставка - бесплатно!"))
print(istext(b'\xe4\xb8\xad\xe6\x96\x87'))
print(istext(
'{"web_uri":"ad-site-i18n-sg/202103185d0d723d88b7f642452dac73","height":336,"width":336,"file_name":""}'))
print(istext(b"\xe4\xb8\xad\xe6\x96\x87"))
print(
istext(
'{"web_uri":"ad-site-i18n-sg/202103185d0d723d88b7f642452dac73","height":336,"width":336,"file_name":""}'
)
)

View File

@@ -5,11 +5,19 @@ import json
from loguru import logger
import thriftpy2
from thriftpy2.protocol import (TBinaryProtocolFactory, TCompactProtocolFactory, TCyBinaryProtocolFactory,
TJSONProtocolFactory)
from thriftpy2.protocol import (
TBinaryProtocolFactory,
TCompactProtocolFactory,
TCyBinaryProtocolFactory,
TJSONProtocolFactory,
)
from thriftpy2.rpc import make_client
from thriftpy2.transport import (TBufferedTransportFactory, TCyBufferedTransportFactory, TCyFramedTransportFactory,
TFramedTransportFactory)
from thriftpy2.transport import (
TBufferedTransportFactory,
TCyBufferedTransportFactory,
TCyFramedTransportFactory,
TFramedTransportFactory,
)
from thriftpy2.utils import deserialize
from httprunner.thrift.data_convertor import json2thrift, thrift2json, thrift2dict
@@ -57,9 +65,17 @@ def get_trans_factory(trans_type):
class ThriftClient(object):
def __init__(self, thrift_file, service_name, ip, port, include_dirs=None, timeout=3000, proto_type=ProtoType.pCyBinary,
trans_type=TransType.tCyBuffered):
def __init__(
self,
thrift_file,
service_name,
ip,
port,
include_dirs=None,
timeout=3000,
proto_type=ProtoType.pCyBinary,
trans_type=TransType.tCyBuffered,
):
self.thrift_file = thrift_file
self.include_dirs = include_dirs
self.service_name = service_name
@@ -69,32 +85,54 @@ class ThriftClient(object):
self.proto_type = proto_type
self.trans_type = trans_type
try:
logger.debug('init thrift module: thrift_file=%s, module_name=%s', thrift_file,
str(self.service_name) + '_thrift')
self.thrift_module = thriftpy2.load(self.thrift_file, module_name=str(self.service_name) + '_thrift',
include_dirs=self.include_dirs)
logger.debug(
"init thrift module: thrift_file=%s, module_name=%s",
thrift_file,
str(self.service_name) + "_thrift",
)
self.thrift_module = thriftpy2.load(
self.thrift_file,
module_name=str(self.service_name) + "_thrift",
include_dirs=self.include_dirs,
)
self.thrift_service_obj = getattr(self.thrift_module, self.service_name)
logger.debug('init thrift client: service_name=%s, ip=%s, port=%s', self.thrift_service_obj, ip, port)
self.client = make_client(self.thrift_service_obj, self.ip, int(self.port), timeout=self.timeout,
proto_factory=get_proto_factory(self.proto_type),
trans_factory=get_trans_factory(self.trans_type))
logger.debug(
"init thrift client: service_name=%s, ip=%s, port=%s",
self.thrift_service_obj,
ip,
port,
)
self.client = make_client(
self.thrift_service_obj,
self.ip,
int(self.port),
timeout=self.timeout,
proto_factory=get_proto_factory(self.proto_type),
trans_factory=get_trans_factory(self.trans_type),
)
except Exception as e:
self.thrift_module = None
self.thrift_service_obj = None
self.client = None
logger.exception('init thrift module and client failed: {}'.format(e))
logger.exception("init thrift module and client failed: {}".format(e))
finally:
thriftpy2.parser.parser.thrift_stack = []
def get_client(self):
return self.client
def send_request(self, request_data, request_method=''):
thrift_req_cls = getattr(self.thrift_service_obj, request_method + '_args').thrift_spec[1][2]
def send_request(self, request_data, request_method=""):
thrift_req_cls = getattr(
self.thrift_service_obj, request_method + "_args"
).thrift_spec[1][2]
request_obj = json2thrift(json.dumps(request_data), thrift_req_cls)
logger.debug('send thrift request: request_method=%s, request_obj=%s', request_method, request_obj)
logger.debug(
"send thrift request: request_method=%s, request_obj=%s",
request_method,
request_obj,
)
response_obj = getattr(self.client, request_method)(request_obj)
logger.debug('thrift response = %s', response_obj)
logger.debug("thrift response = %s", response_obj)
return thrift2dict(response_obj)
def __del__(self):