diff --git a/httprunner/__init__.py b/httprunner/__init__.py index 877d9bc3..6fbfbe8a 100644 --- a/httprunner/__init__.py +++ b/httprunner/__init__.py @@ -7,8 +7,16 @@ from httprunner.runner import HttpRunner from httprunner.step import Step from httprunner.step_request import RunRequest from httprunner.step_testcase import RunTestCase -from httprunner.step_sql_request import RunSqlRequest, StepSqlRequestValidation, StepSqlRequestExtraction -from httprunner.step_thrift_request import RunThriftRequest, StepThriftRequestValidation, StepThriftRequestExtraction +from httprunner.step_sql_request import ( + RunSqlRequest, + StepSqlRequestValidation, + StepSqlRequestExtraction, +) +from httprunner.step_thrift_request import ( + RunThriftRequest, + StepThriftRequestValidation, + StepThriftRequestExtraction, +) __all__ = [ "__version__", diff --git a/httprunner/compat_test.py b/httprunner/compat_test.py index 391133a1..064d3813 100644 --- a/httprunner/compat_test.py +++ b/httprunner/compat_test.py @@ -43,8 +43,7 @@ class TestCompat(unittest.TestCase): "body.data.buildings[0].building_id", ) self.assertEqual( - compat._convert_jmespath("body.users[-1]"), - "body.users[-1]", + compat._convert_jmespath("body.users[-1]"), "body.users[-1]", ) self.assertEqual( compat._convert_jmespath("body.result.WorkNode_-1"), diff --git a/httprunner/database/engine.py b/httprunner/database/engine.py index 8a3cd4c7..7919739c 100644 --- a/httprunner/database/engine.py +++ b/httprunner/database/engine.py @@ -27,7 +27,7 @@ class DBEngine(object): """ for k, v in row.items(): if isinstance(v, datetime.datetime): - row[k] = v.strftime('%Y-%m-%d %H:%M:%S') + row[k] = v.strftime("%Y-%m-%d %H:%M:%S") elif isinstance(v, datetime.date): row[k] = v.strftime("%Y-%m-%d") elif isinstance(v, str): @@ -73,7 +73,6 @@ class DBEngine(object): def update(self, query, commit=True): return self._fetch(query=query, commit=commit) -if __name__ == '__main__': - db = DBEngine( - f"mysql+pymysql://xxxxx:xxxxx@10.0.0.1:3306/dbname?charset=utf8mb4") +if __name__ == "__main__": + db = DBEngine(f"mysql+pymysql://xxxxx:xxxxx@10.0.0.1:3306/dbname?charset=utf8mb4") diff --git a/httprunner/loader_test.py b/httprunner/loader_test.py index 7b09d87b..95449f43 100644 --- a/httprunner/loader_test.py +++ b/httprunner/loader_test.py @@ -97,11 +97,7 @@ class TestLoader(unittest.TestCase): ) def test_load_env_path_not_exist(self): - dot_env_path = os.path.join( - os.getcwd(), - "tests", - "data", - ) + dot_env_path = os.path.join(os.getcwd(), "tests", "data",) env_variables_mapping = loader.load_dot_env_file(dot_env_path) self.assertEqual(env_variables_mapping, {}) diff --git a/httprunner/make.py b/httprunner/make.py index c2b3d3b2..c16dbc43 100644 --- a/httprunner/make.py +++ b/httprunner/make.py @@ -534,8 +534,7 @@ def main_make(tests_paths: List[Text]) -> List[Text]: def init_make_parser(subparsers): """make testcases: parse command line options and run commands.""" parser = subparsers.add_parser( - "make", - help="Convert YAML/JSON testcases to pytest cases.", + "make", help="Convert YAML/JSON testcases to pytest cases.", ) parser.add_argument( "testcase_path", nargs="*", help="Specify YAML/JSON testcase file/folder path" diff --git a/httprunner/make_test.py b/httprunner/make_test.py index f3a80325..9e4b3bc7 100644 --- a/httprunner/make_test.py +++ b/httprunner/make_test.py @@ -73,8 +73,7 @@ from request_methods.request_with_functions_test import ( content, ) self.assertIn( - ".call(RequestWithFunctions)", - content, + ".call(RequestWithFunctions)", content, ) def test_make_testcase_folder(self): @@ -112,8 +111,7 @@ from request_methods.request_with_functions_test import ( ) loader.project_meta = None self.assertEqual( - ensure_file_abs_path_valid(os.getcwd()), - os.getcwd(), + ensure_file_abs_path_valid(os.getcwd()), os.getcwd(), ) loader.project_meta = None self.assertEqual( @@ -124,17 +122,11 @@ from request_methods.request_with_functions_test import ( def test_convert_testcase_path(self): self.assertEqual( convert_testcase_path(os.path.join(self.data_dir, "a-b.c", "2 3.yml")), - ( - os.path.join(self.data_dir, "a_b_c", "T2_3_test.py"), - "T23", - ), + (os.path.join(self.data_dir, "a_b_c", "T2_3_test.py"), "T23",), ) self.assertEqual( convert_testcase_path(os.path.join(self.data_dir, "a-b.c", "中文case.yml")), - ( - os.path.join(self.data_dir, "a_b_c", "中文case_test.py"), - "中文Case", - ), + (os.path.join(self.data_dir, "a_b_c", "中文case_test.py"), "中文Case",), ) def test_make_config_chain_style(self): @@ -153,11 +145,7 @@ from request_methods.request_with_functions_test import ( def test_make_teststep_chain_style(self): step = { "name": "get with params", - "variables": { - "foo1": "bar1", - "foo2": 123, - "sum_v": "${sum_two(1, 2)}", - }, + "variables": {"foo1": "bar1", "foo2": 123, "sum_v": "${sum_two(1, 2)}",}, "request": { "method": "GET", "url": "/get", diff --git a/httprunner/models.py b/httprunner/models.py index 7aec9ad3..689d57d0 100644 --- a/httprunner/models.py +++ b/httprunner/models.py @@ -77,10 +77,11 @@ class TransportEnum(Text, Enum): class TThriftRequest(BaseModel): """ rpc request model""" - method: Text = '' + + method: Text = "" params: Dict = {} thrift_client: Any = None - idl_path: Text = '' # idl local path + idl_path: Text = "" # idl local path timeout: int = 10 # sec transport: TransportEnum = TransportEnum.BUFFERED include_dirs: List[Union[Text, None]] = [] # param of thriftpy2.load @@ -106,6 +107,7 @@ class SqlMethodEnum(Text, Enum): class TSqlRequest(BaseModel): """ sql request model""" + db_config: TConfigDB = TConfigDB() method: SqlMethodEnum = None sql: Text = None diff --git a/httprunner/parser.py b/httprunner/parser.py index e1adb7be..047847dd 100644 --- a/httprunner/parser.py +++ b/httprunner/parser.py @@ -476,9 +476,7 @@ def parse_variables_mapping( return parsed_variables -def parse_parameters( - parameters: Dict, -) -> List[Dict]: +def parse_parameters(parameters: Dict,) -> List[Dict]: """parse parameters and generate cartesian product. Args: diff --git a/httprunner/response.py b/httprunner/response.py index 7d4b738f..5d654654 100644 --- a/httprunner/response.py +++ b/httprunner/response.py @@ -124,20 +124,17 @@ class ResponseObjectBase(object): self.parser = parser self.validation_results: Dict = {} - def extract(self, - extractors: Dict[Text, Text], - variables_mapping: VariablesMapping = None, - ) -> Dict[Text, Any]: + def extract( + self, extractors: Dict[Text, Text], variables_mapping: VariablesMapping = None, + ) -> Dict[Text, Any]: if not extractors: return {} extract_mapping = {} for key, field in extractors.items(): - if '$' in field: + if "$" in field: # field contains variable or function - field = self.parser.parse_data( - field, variables_mapping - ) + field = self.parser.parse_data(field, variables_mapping) field_value = self._search_jmespath(field) extract_mapping[key] = field_value @@ -148,9 +145,7 @@ class ResponseObjectBase(object): raise NotImplementedError("_search_jmespath not override") def validate( - self, - validators: Validators, - variables_mapping: VariablesMapping = None, + self, validators: Validators, variables_mapping: VariablesMapping = None, ): variables_mapping = variables_mapping or {} @@ -173,9 +168,7 @@ class ResponseObjectBase(object): check_item = u_validator["check"] if "$" in check_item: # check_item is variable or function - check_item = self.parser.parse_data( - check_item, variables_mapping - ) + check_item = self.parser.parse_data(check_item, variables_mapping) check_item = parse_string_value(check_item) if check_item and isinstance(check_item, Text): diff --git a/httprunner/response_test.py b/httprunner/response_test.py index 7ab7a6fb..4bf61dac 100644 --- a/httprunner/response_test.py +++ b/httprunner/response_test.py @@ -61,9 +61,6 @@ class TestResponse(unittest.TestCase): def test_validate_functions(self): variables_mapping = {"index": 1} self.resp_obj.validate( - [ - {"eq": ["${get_num(0)}", 0]}, - {"eq": ["${get_num($index)}", 1]}, - ], + [{"eq": ["${get_num(0)}", 0]}, {"eq": ["${get_num($index)}", 1]},], variables_mapping=variables_mapping, ) diff --git a/httprunner/runner.py b/httprunner/runner.py index bee29aca..558a4fa9 100644 --- a/httprunner/runner.py +++ b/httprunner/runner.py @@ -92,7 +92,7 @@ class SessionRunner(object): def with_thrift_client(self, thrift_client) -> "SessionRunner": self.thrift_client = thrift_client - def with_db_engine(self,db_engine): + def with_db_engine(self, db_engine): self.db_engine = db_engine def __parse_config(self, param: Dict = None) -> None: diff --git a/httprunner/step.py b/httprunner/step.py index 974a6457..8c42aea5 100644 --- a/httprunner/step.py +++ b/httprunner/step.py @@ -8,7 +8,11 @@ from httprunner.step_request import ( StepRequestValidation, ) from httprunner.step_testcase import StepRefCase -from httprunner.step_sql_request import RunSqlRequest, StepSqlRequestValidation, StepSqlRequestExtraction +from httprunner.step_sql_request import ( + RunSqlRequest, + StepSqlRequestValidation, + StepSqlRequestExtraction, +) class Step(object): diff --git a/httprunner/step_request.py b/httprunner/step_request.py index 9299cba2..bbb6fa2b 100644 --- a/httprunner/step_request.py +++ b/httprunner/step_request.py @@ -67,10 +67,7 @@ def call_hooks( def run_step_request(runner: HttpRunner, step: TStep) -> StepResult: """run teststep: request""" - step_result = StepResult( - name=step.name, - success=False, - ) + step_result = StepResult(name=step.name, success=False,) start_time = time.time() step.variables = runner.merge_step_variables(step.variables) @@ -82,8 +79,7 @@ def run_step_request(runner: HttpRunner, step: TStep) -> StepResult: request_dict.pop("upload", None) parsed_request_dict = runner.parser.parse_data(request_dict, step.variables) parsed_request_dict["headers"].setdefault( - "HRUN-Request-ID", - f"HRUN-{runner.case_id}-{str(int(time.time() * 1000))[-6:]}", + "HRUN-Request-ID", f"HRUN-{runner.case_id}-{str(int(time.time() * 1000))[-6:]}", ) step.variables["request"] = parsed_request_dict diff --git a/httprunner/step_sql_request.py b/httprunner/step_sql_request.py index d7a29e4b..dd7b98cb 100644 --- a/httprunner/step_sql_request.py +++ b/httprunner/step_sql_request.py @@ -10,7 +10,11 @@ from httprunner.models import IStep, StepResult, TStep from httprunner.models import TSqlRequest, SqlMethodEnum from httprunner.response import SqlResponseObject from httprunner.runner import HttpRunner -from httprunner.step_request import call_hooks, StepRequestExtraction, StepRequestValidation +from httprunner.step_request import ( + call_hooks, + StepRequestExtraction, + StepRequestValidation, +) from httprunner.database.engine import DBEngine from httprunner.exceptions import SqlMethodNotSupport @@ -19,32 +23,42 @@ def run_step_sql_request(runner: HttpRunner, step: TStep) -> StepResult: """run teststep:sql request""" start_time = time.time() - step_result = StepResult( - name=step.name, - success=False, - ) + step_result = StepResult(name=step.name, success=False,) step.variables = runner.merge_step_variables(step.variables) # parse request_dict = step.sql_request.dict() - parsed_request_dict = runner.parser.parse_data( - request_dict, step.variables - ) + parsed_request_dict = runner.parser.parse_data(request_dict, step.variables) config = runner.get_config() - parsed_request_dict["db_config"]["psm"] = parsed_request_dict["db_config"]["psm"] or config.db.psm - parsed_request_dict["db_config"]["user"] = parsed_request_dict["db_config"]["user"] or config.db.user - parsed_request_dict["db_config"]["password"] = parsed_request_dict["db_config"]["password"] or config.db.password - parsed_request_dict["db_config"]["ip"] = parsed_request_dict["db_config"]["ip"] or config.db.ip - parsed_request_dict["db_config"]["port"] = parsed_request_dict["db_config"]["port"] or config.db.port - parsed_request_dict["db_config"]["database"] = parsed_request_dict["db_config"]["database"] or config.db.database + parsed_request_dict["db_config"]["psm"] = ( + parsed_request_dict["db_config"]["psm"] or config.db.psm + ) + parsed_request_dict["db_config"]["user"] = ( + parsed_request_dict["db_config"]["user"] or config.db.user + ) + parsed_request_dict["db_config"]["password"] = ( + parsed_request_dict["db_config"]["password"] or config.db.password + ) + parsed_request_dict["db_config"]["ip"] = ( + parsed_request_dict["db_config"]["ip"] or config.db.ip + ) + parsed_request_dict["db_config"]["port"] = ( + parsed_request_dict["db_config"]["port"] or config.db.port + ) + parsed_request_dict["db_config"]["database"] = ( + parsed_request_dict["db_config"]["database"] or config.db.database + ) if parsed_request_dict["db_config"]["psm"]: - runner.db_engine = DBEngine(f'mysql+pymysql://:@/?charset=utf8mb4&db_psm={parsed_request_dict["psm"]}') + runner.db_engine = DBEngine( + f'mysql+pymysql://:@/?charset=utf8mb4&db_psm={parsed_request_dict["psm"]}' + ) else: runner.db_engine = DBEngine( f'mysql+pymysql://{parsed_request_dict["db_config"]["user"]}:' f'{parsed_request_dict["db_config"]["password"]}@{parsed_request_dict["db_config"]["ip"]}:' f'{parsed_request_dict["db_config"]["port"]}/{parsed_request_dict["db_config"]["database"]}' - f'?charset=utf8mb4') + f"?charset=utf8mb4" + ) # parsed_request_dict["headers"].setdefault( # "HRUN-Request-ID", @@ -57,19 +71,23 @@ def run_step_sql_request(runner: HttpRunner, step: TStep) -> StepResult: logger.info(f"Executing SQL: {parsed_request_dict['sql']}") if step.sql_request.method == SqlMethodEnum.FETCHONE: - sql_resp = runner.db_engine.fetchone(parsed_request_dict['sql']) + sql_resp = runner.db_engine.fetchone(parsed_request_dict["sql"]) elif step.sql_request.method == SqlMethodEnum.INSERT: - sql_resp = runner.db_engine.insert(parsed_request_dict['sql']) + sql_resp = runner.db_engine.insert(parsed_request_dict["sql"]) elif step.sql_request.method == SqlMethodEnum.FETCHMANY: - sql_resp = runner.db_engine.fetchmany(parsed_request_dict['sql'], parsed_request_dict['size']) + sql_resp = runner.db_engine.fetchmany( + parsed_request_dict["sql"], parsed_request_dict["size"] + ) elif step.sql_request.method == SqlMethodEnum.FETCHALL: - sql_resp = runner.db_engine.fetchall(parsed_request_dict['sql']) + sql_resp = runner.db_engine.fetchall(parsed_request_dict["sql"]) elif step.sql_request.method == SqlMethodEnum.UPDATE: - sql_resp = runner.db_engine.update(parsed_request_dict['sql']) + sql_resp = runner.db_engine.update(parsed_request_dict["sql"]) elif step.sql_request.method == SqlMethodEnum.DELETE: - sql_resp = runner.db_engine.delete(parsed_request_dict['sql']) + sql_resp = runner.db_engine.delete(parsed_request_dict["sql"]) else: - raise SqlMethodNotSupport(f"step.sql_request.method {parsed_request_dict['method']} not support") + raise SqlMethodNotSupport( + f"step.sql_request.method {parsed_request_dict['method']} not support" + ) resp_obj = SqlResponseObject(sql_resp, parser=runner.parser) step.variables["sql_response"] = resp_obj @@ -107,9 +125,7 @@ def run_step_sql_request(runner: HttpRunner, step: TStep) -> StepResult: # validate validators = step.validators try: - resp_obj.validate( - validators, variables_mapping - ) + resp_obj.validate(validators, variables_mapping) step_result.success = True except ValidationFailure: log_sql_req_resp_details() @@ -128,7 +144,7 @@ class StepSqlRequestValidation(StepRequestValidation): def __init__(self, step: TStep): self.__step = step super().__init__(step) - + def run(self, runner: HttpRunner): return run_step_sql_request(runner, self.__step) @@ -154,7 +170,9 @@ class RunSqlRequest(IStep): self.__step.variables.update(variables) return self - def with_db_config(self, psm=None, user=None, password=None, ip=None, port=None, database=None): + def with_db_config( + self, psm=None, user=None, password=None, ip=None, port=None, database=None + ): if psm: self.__step.sql_request.db_config.psm = psm if user: @@ -205,7 +223,9 @@ class RunSqlRequest(IStep): self.__step.retry_interval = retry_interval return self - def teardown_hook(self, hook: Text, assign_var_name: Text = None) -> "RunSqlRequest": + def teardown_hook( + self, hook: Text, assign_var_name: Text = None + ) -> "RunSqlRequest": if assign_var_name: self.__step.teardown_hooks.append({assign_var_name: hook}) else: @@ -239,6 +259,8 @@ class RunSqlRequest(IStep): def validate(self) -> StepSqlRequestValidation: return StepSqlRequestValidation(self.__step) - def with_jmespath(self, jmes_path: Text, var_name: Text) -> "StepSqlRequestExtraction": + def with_jmespath( + self, jmes_path: Text, var_name: Text + ) -> "StepSqlRequestExtraction": self.__step.extract[var_name] = jmes_path return StepSqlRequestExtraction(self.__step) diff --git a/httprunner/step_thrift_request.py b/httprunner/step_thrift_request.py index e3d69da8..55bf02ec 100644 --- a/httprunner/step_thrift_request.py +++ b/httprunner/step_thrift_request.py @@ -7,7 +7,11 @@ from httprunner import utils from httprunner.exceptions import ValidationFailure from httprunner.models import IStep, StepResult, TStep, ProtoType, TransType from httprunner.runner import HttpRunner -from httprunner.step_request import call_hooks, StepRequestExtraction, StepRequestValidation +from httprunner.step_request import ( + call_hooks, + StepRequestExtraction, + StepRequestValidation, +) from httprunner.models import TThriftRequest from httprunner.response import ThriftResponseObject @@ -18,29 +22,40 @@ def run_step_thrift_request(runner: HttpRunner, step: TStep) -> StepResult: """run teststep:thrift request""" start_time = time.time() - step_result = StepResult( - name=step.name, - success=False, - ) + step_result = StepResult(name=step.name, success=False,) step.variables = runner.merge_step_variables(step.variables) # parse request_dict = step.thrift_request.dict() - parsed_request_dict = runner.parser.parse_data( - request_dict, step.variables - ) + parsed_request_dict = runner.parser.parse_data(request_dict, step.variables) config = runner.get_config() parsed_request_dict["psm"] = parsed_request_dict["psm"] or config.thrift.psm parsed_request_dict["env"] = parsed_request_dict["env"] or config.thrift.env - parsed_request_dict["cluster"] = parsed_request_dict["cluster"] or config.thrift.cluster - parsed_request_dict["idl_path"] = parsed_request_dict["idl_path"] or config.thrift.idl_path - parsed_request_dict["include_dirs"] = parsed_request_dict["include_dirs"] or config.thrift.include_dirs - parsed_request_dict["method"] = parsed_request_dict["method"] or config.thrift.method - parsed_request_dict["service_name"] = parsed_request_dict["service_name"] or config.thrift.service_name + parsed_request_dict["cluster"] = ( + parsed_request_dict["cluster"] or config.thrift.cluster + ) + parsed_request_dict["idl_path"] = ( + parsed_request_dict["idl_path"] or config.thrift.idl_path + ) + parsed_request_dict["include_dirs"] = ( + parsed_request_dict["include_dirs"] or config.thrift.include_dirs + ) + parsed_request_dict["method"] = ( + parsed_request_dict["method"] or config.thrift.method + ) + parsed_request_dict["service_name"] = ( + parsed_request_dict["service_name"] or config.thrift.service_name + ) parsed_request_dict["ip"] = parsed_request_dict["ip"] or config.thrift.ip parsed_request_dict["port"] = parsed_request_dict["port"] or config.thrift.port - parsed_request_dict["proto_type"] = parsed_request_dict["proto_type"] or config.thrift.proto_type - parsed_request_dict["trans_port"] = parsed_request_dict["trans_type"] or config.thrift.trans_type - parsed_request_dict["timeout"] = parsed_request_dict["timeout"] or config.thrift.timeout + parsed_request_dict["proto_type"] = ( + parsed_request_dict["proto_type"] or config.thrift.proto_type + ) + parsed_request_dict["trans_port"] = ( + parsed_request_dict["trans_type"] or config.thrift.trans_type + ) + parsed_request_dict["timeout"] = ( + parsed_request_dict["timeout"] or config.thrift.timeout + ) parsed_request_dict["thrift_client"] = parsed_request_dict["thrift_client"] # parsed_request_dict["headers"].setdefault( @@ -53,17 +68,24 @@ def run_step_thrift_request(runner: HttpRunner, step: TStep) -> StepResult: runner.thrift_client = parsed_request_dict["thrift_client"] if not runner.thrift_client: - runner.thrift_client = ThriftClient(parsed_request_dict["idl_path"], parsed_request_dict["service_name"], - parsed_request_dict["ip"], parsed_request_dict["port"], - parsed_request_dict["timeout"], parsed_request_dict["proto_type"], - parsed_request_dict["trans_port"]) + runner.thrift_client = ThriftClient( + parsed_request_dict["idl_path"], + parsed_request_dict["service_name"], + parsed_request_dict["ip"], + parsed_request_dict["port"], + parsed_request_dict["timeout"], + parsed_request_dict["proto_type"], + parsed_request_dict["trans_port"], + ) # setup hooks if step.setup_hooks: call_hooks(runner, step.setup_hooks, step.variables, "setup request") # thrift request - resp = runner.thrift_client.send_request(parsed_request_dict["params"], parsed_request_dict["method"]) + resp = runner.thrift_client.send_request( + parsed_request_dict["params"], parsed_request_dict["method"] + ) resp_obj = ThriftResponseObject(resp, parser=runner.parser) step.variables["thrift_response"] = resp_obj @@ -72,7 +94,9 @@ def run_step_thrift_request(runner: HttpRunner, step: TStep) -> StepResult: call_hooks(runner, step.teardown_hooks, step.variables, "teardown request") def log_thrift_req_resp_details(): - err_msg = "\n{} THRIFT DETAILED REQUEST & RESPONSE {}\n".format("*" * 32, "*" * 32) + err_msg = "\n{} THRIFT DETAILED REQUEST & RESPONSE {}\n".format( + "*" * 32, "*" * 32 + ) # log request err_msg += "====== thrift request details ======\n" @@ -101,9 +125,7 @@ def run_step_thrift_request(runner: HttpRunner, step: TStep) -> StepResult: # validate validators = step.validators try: - resp_obj.validate( - validators, variables_mapping - ) + resp_obj.validate(validators, variables_mapping) step_result.success = True except ValidationFailure: log_thrift_req_resp_details() @@ -153,7 +175,9 @@ class RunThriftRequest(IStep): self.__step.retry_interval = retry_interval return self - def teardown_hook(self, hook: Text, assign_var_name: Text = None) -> "RunThriftRequest": + def teardown_hook( + self, hook: Text, assign_var_name: Text = None + ) -> "RunThriftRequest": if assign_var_name: self.__step.teardown_hooks.append({assign_var_name: hook}) else: @@ -182,11 +206,13 @@ class RunThriftRequest(IStep): self.__step.thrift_request.include_dirs = [idl_root_path] return self - def with_thrift_client(self, thrift_client: Union["ThriftClient", str]) -> "RunThriftRequest": + def with_thrift_client( + self, thrift_client: Union["ThriftClient", str] + ) -> "RunThriftRequest": self.__step.thrift_request.thrift_client = thrift_client return self - def with_ip(self,ip: str) -> "RunThriftRequest": + def with_ip(self, ip: str) -> "RunThriftRequest": self.__step.thrift_request.ip = ip return self @@ -194,11 +220,11 @@ class RunThriftRequest(IStep): self.__step.thrift_request.port = port return self - def with_proto_type(self,proto_type:ProtoType) -> "RunThriftRequest": + def with_proto_type(self, proto_type: ProtoType) -> "RunThriftRequest": self.__step.thrift_request.proto_type = proto_type return self - def with_trans_type(self,trans_type:TransType) -> "RunThriftRequest": + def with_trans_type(self, trans_type: TransType) -> "RunThriftRequest": self.__step.thrift_request.proto_type = trans_type return self @@ -220,6 +246,8 @@ class RunThriftRequest(IStep): def validate(self) -> StepThriftRequestValidation: return StepThriftRequestValidation(self.__step) - def with_jmespath(self, jmes_path: Text, var_name: Text) -> "StepThriftRequestExtraction": + def with_jmespath( + self, jmes_path: Text, var_name: Text + ) -> "StepThriftRequestExtraction": self.__step.extract[var_name] = jmes_path return StepThriftRequestExtraction(self.__step) diff --git a/httprunner/thrift/data_convertor.py b/httprunner/thrift/data_convertor.py index f54ab3cb..c25e034b 100644 --- a/httprunner/thrift/data_convertor.py +++ b/httprunner/thrift/data_convertor.py @@ -19,21 +19,21 @@ text_characters = "".join(map(chr, range(32, 127))) + "\n\r\t\b" _null_trans = str.maketrans("", "") ESCAPE = re.compile(r'[\x00-\x1f\\"\b\f\n\r\t]') ESCAPE_ASCII = re.compile(r'([\\"]|[^\ -~])') -HAS_UTF8 = re.compile(r'[\x80-\xff]') +HAS_UTF8 = re.compile(r"[\x80-\xff]") ESCAPE_DCT = { - '\\': '\\\\', + "\\": "\\\\", '"': '\\"', - '\b': '\\b', - '\f': '\\f', - '\n': '\\n', - '\r': '\\r', - '\t': '\\t', + "\b": "\\b", + "\f": "\\f", + "\n": "\\n", + "\r": "\\r", + "\t": "\\t", } for i in range(0x20): - ESCAPE_DCT.setdefault(chr(i), '\\u{0:04x}'.format(i)) + ESCAPE_DCT.setdefault(chr(i), "\\u{0:04x}".format(i)) # ESCAPE_DCT.setdefault(chr(i), '\\u%04x' % (i,)) -INFINITY = float('inf') +INFINITY = float("inf") FLOAT_REPR = repr @@ -66,7 +66,7 @@ def unicode_2_utf8_keep_native(para): elif type(para) is tuple: return tuple(unicode_2_utf8_keep_native(list(para))) elif type(para) is str: - return para.encode('utf-8') + return para.encode("utf-8") else: logging.debug("type========", type(para)) # if issubclass(type(para), dict): @@ -93,7 +93,7 @@ def py_encode_basestring_ascii(s): """ if isinstance(s, str) and HAS_UTF8.search(s) is not None: - s = s.decode('utf-8') + s = s.decode("utf-8") def replace(match): s = match.group(0) @@ -102,27 +102,25 @@ def py_encode_basestring_ascii(s): except KeyError: n = ord(s) if n < 0x10000: - return '\\u{0:04x}'.format(n) + return "\\u{0:04x}".format(n) # return '\\u%04x' % (n,) else: # surrogate pair n -= 0x10000 - s1 = 0xd800 | ((n >> 10) & 0x3ff) - s2 = 0xdc00 | (n & 0x3ff) - return '\\u{0:04x}\\u{1:04x}'.format(s1, s2) + s1 = 0xD800 | ((n >> 10) & 0x3FF) + s2 = 0xDC00 | (n & 0x3FF) + return "\\u{0:04x}\\u{1:04x}".format(s1, s2) # return '\\u%04x\\u%04x' % (s1, s2) return '"' + str(ESCAPE_ASCII.sub(replace, s)) + '"' -encode_basestring_ascii = ( - c_encode_basestring_ascii or py_encode_basestring_ascii) +encode_basestring_ascii = c_encode_basestring_ascii or py_encode_basestring_ascii class ThriftJSONDecoder(json.JSONDecoder): - def __init__(self, *args, **kwargs): - self._thrift_class = kwargs.pop('thrift_class') + self._thrift_class = kwargs.pop("thrift_class") super(ThriftJSONDecoder, self).__init__(*args, **kwargs) def decode(self, json_str): @@ -130,9 +128,12 @@ class ThriftJSONDecoder(json.JSONDecoder): dct = json_str else: dct = super(ThriftJSONDecoder, self).decode(json_str) - return self._convert(dct, TType.STRUCT, - # (self._thrift_class, self._thrift_class.thrift_spec)) - self._thrift_class) + return self._convert( + dct, + TType.STRUCT, + # (self._thrift_class, self._thrift_class.thrift_spec)) + self._thrift_class, + ) def _convert(self, val, ttype, ttype_info): if ttype == TType.STRUCT: @@ -156,7 +157,9 @@ class ThriftJSONDecoder(json.JSONDecoder): if val is None or field_name not in val: continue - converted_val = self._convert(val[field_name], field_ttype, field_ttype_info) + converted_val = self._convert( + val[field_name], field_ttype, field_ttype_info + ) setattr(ret, field_name, converted_val) elif ttype == TType.LIST: if type(ttype_info) != tuple: # 说明是基础类型了, 无法在细分 @@ -174,7 +177,9 @@ class ThriftJSONDecoder(json.JSONDecoder): else: (element_ttype, element_ttype_info) = ttype_info if val is not None: - ret = set([self._convert(x, element_ttype, element_ttype_info) for x in val]) + ret = set( + [self._convert(x, element_ttype, element_ttype_info) for x in val] + ) else: ret = None @@ -193,8 +198,15 @@ class ThriftJSONDecoder(json.JSONDecoder): val_ttype, val_ttype_info = ttype_info[1] if val is not None: - ret = dict([(self._convert(k, key_ttype, key_ttype_info), - self._convert(v, val_ttype, val_ttype_info)) for (k, v) in val.items()]) + ret = dict( + [ + ( + self._convert(k, key_ttype, key_ttype_info), + self._convert(v, val_ttype, val_ttype_info), + ) + for (k, v) in val.items() + ] + ) else: ret = None elif ttype == TType.STRING: @@ -228,13 +240,15 @@ class ThriftJSONDecoder(json.JSONDecoder): else: ret = None else: - raise TypeError('Unrecognized thrift field type: %s' % ttype) + raise TypeError("Unrecognized thrift field type: %s" % ttype) return ret def json2thrift(json_str, thrift_class): logging.debug(json_str) - return json.loads(json_str, cls=ThriftJSONDecoder, thrift_class=thrift_class, strict=False) + return json.loads( + json_str, cls=ThriftJSONDecoder, thrift_class=thrift_class, strict=False + ) def dumper(obj): @@ -245,14 +259,33 @@ def dumper(obj): class MyJSONEncoder(json.JSONEncoder): - def __init__(self, skipkeys=False, ensure_ascii=True, check_circular=True, - allow_nan=True, indent=None, separators=None, - encoding='utf-8', default=None, sort_keys=False, **kw): - super(MyJSONEncoder, self).__init__(skipkeys=skipkeys, ensure_ascii=ensure_ascii, - check_circular=check_circular, allow_nan=allow_nan, indent=indent, - separators=separators, encoding=encoding, default=default, - sort_keys=sort_keys) - self.skip_nonutf8_value = kw.get('skip_nonutf8_value', False) # 默认不skip忽略非utf-8编码的字段 + def __init__( + self, + skipkeys=False, + ensure_ascii=True, + check_circular=True, + allow_nan=True, + indent=None, + separators=None, + encoding="utf-8", + default=None, + sort_keys=False, + **kw + ): + super(MyJSONEncoder, self).__init__( + skipkeys=skipkeys, + ensure_ascii=ensure_ascii, + check_circular=check_circular, + allow_nan=allow_nan, + indent=indent, + separators=separators, + encoding=encoding, + default=default, + sort_keys=sort_keys, + ) + self.skip_nonutf8_value = kw.get( + "skip_nonutf8_value", False + ) # 默认不skip忽略非utf-8编码的字段 def encode(self, o): """Return a JSON string representation of a Python data structure. @@ -266,8 +299,7 @@ class MyJSONEncoder(json.JSONEncoder): if isinstance(o, str): _encoding = self.encoding - if (_encoding is not None - and not (_encoding == 'utf-8')): + if _encoding is not None and not (_encoding == "utf-8"): o = o.decode(_encoding) if self.ensure_ascii: return encode_basestring_ascii(o) @@ -288,10 +320,10 @@ class MyJSONEncoder(json.JSONEncoder): tmp_chunks.append(unicode_2_utf8_keep_native(chunk)) except Exception as err: logging.debug(traceback.format_exc()) - return ''.join(tmp_chunks) + return "".join(tmp_chunks) # 保留老的逻辑, /usr/lib/python2.7/package/json/__init__.py dumps接口 - return ''.join(chunks) + return "".join(chunks) class ThriftJSONEncoder(json.JSONEncoder): @@ -299,13 +331,32 @@ class ThriftJSONEncoder(json.JSONEncoder): add by braver(Braver@bytedance.com) """ - def __init__(self, skipkeys=False, ensure_ascii=True, check_circular=True, - allow_nan=True, indent=None, separators=None, default=None, sort_keys=False, **kw): + def __init__( + self, + skipkeys=False, + ensure_ascii=True, + check_circular=True, + allow_nan=True, + indent=None, + separators=None, + default=None, + sort_keys=False, + **kw + ): - super(ThriftJSONEncoder, self).__init__(skipkeys=skipkeys, ensure_ascii=ensure_ascii, - check_circular=check_circular, allow_nan=allow_nan, indent=indent, - separators=separators, default=default, sort_keys=sort_keys) - self.skip_nonutf8_value = kw.get('skip_nonutf8_value', False) # 默认不skip忽略非utf-8编码的字段 + super(ThriftJSONEncoder, self).__init__( + skipkeys=skipkeys, + ensure_ascii=ensure_ascii, + check_circular=check_circular, + allow_nan=allow_nan, + indent=indent, + separators=separators, + default=default, + sort_keys=sort_keys, + ) + self.skip_nonutf8_value = kw.get( + "skip_nonutf8_value", False + ) # 默认不skip忽略非utf-8编码的字段 def encode(self, o): """Return a JSON string representation of a Python data structure. @@ -318,8 +369,7 @@ class ThriftJSONEncoder(json.JSONEncoder): if isinstance(o, str): if isinstance(o, str): _encoding = self.encoding - if (_encoding is not None - and not (_encoding == 'utf-8')): + if _encoding is not None and not (_encoding == "utf-8"): o = o.decode(_encoding) if self.ensure_ascii: return encode_basestring_ascii(o) @@ -340,18 +390,18 @@ class ThriftJSONEncoder(json.JSONEncoder): tmp_chunks.append(unicode_2_utf8_keep_native(chunk)) except Exception as err: logging.debug(traceback.format_exc()) - return ''.join(tmp_chunks) + return "".join(tmp_chunks) # 保留老的逻辑, /usr/lib/python2.7/package/json/__init__.py dumps接口 - return ''.join(chunks) + return "".join(chunks) def default(self, o): if isinstance(o, bytes): - return str(o, encoding='utf-8') - if not hasattr(o, 'thrift_spec'): + return str(o, encoding="utf-8") + if not hasattr(o, "thrift_spec"): return super(ThriftJSONEncoder, self).default(o) - spec = getattr(o, 'thrift_spec') + spec = getattr(o, "thrift_spec") ret = {} for tag, field in spec.items(): if field is None: @@ -370,30 +420,42 @@ class ThriftJSONEncoder(json.JSONEncoder): val = list(val) # 统一转成数组(list/set) is_need_binary_bs64 = False if type(field_ttype_info) != tuple: # 基础类型 - if field_ttype_info in [TType.BYTE] and type(val[0]) in [str] and not istext( - val[0]): + if ( + field_ttype_info in [TType.BYTE] + and type(val[0]) in [str] + and not istext(val[0]) + ): is_need_binary_bs64 = True if is_need_binary_bs64: for index, item in enumerate(val): if item and type(item) in [str] and not istext(item): - val[index] = base64.b64encode(item) # 判断为二进制字符串, 需要进行base64编码 - if field_type in [TType.BYTE] and type(val) in [str]: # 说明是string(明文string或者binary) + val[index] = base64.b64encode( + item + ) # 判断为二进制字符串, 需要进行base64编码 + if field_type in [TType.BYTE] and type(val) in [ + str + ]: # 说明是string(明文string或者binary) # 需要对二进制字节字符串字段进行base64编码, 将二进制字节串字段->ascii字符编码的base64编码明文串 if val and not istext(val): # 说明是该字段非空且为binary string - print('4' * 100, val) - val = base64.b64encode(val.encode('utf-8')) + print("4" * 100, val) + val = base64.b64encode(val.encode("utf-8")) # val = base64.b64encode(val) # 进行base64编码处理, 不然该字段序列化为json时会报错 # if val != default: ret[field_name] = val - if 'request_id' in o.__dict__: - ret['request_id'] = o.__dict__['request_id'] - if 'rpc_latency' in o.__dict__: - ret['rpc_latency'] = o.__dict__['rpc_latency'] + if "request_id" in o.__dict__: + ret["request_id"] = o.__dict__["request_id"] + if "rpc_latency" in o.__dict__: + ret["rpc_latency"] = o.__dict__["rpc_latency"] return ret def thrift2json(obj, skip_nonutf8_value=False): - return json.dumps(obj, cls=ThriftJSONEncoder, ensure_ascii=False, skip_nonutf8_value=skip_nonutf8_value) + return json.dumps( + obj, + cls=ThriftJSONEncoder, + ensure_ascii=False, + skip_nonutf8_value=skip_nonutf8_value, + ) def thrift2dict(obj): @@ -403,8 +465,11 @@ def thrift2dict(obj): dict2thrift = json2thrift -if __name__ == '__main__': +if __name__ == "__main__": print(istext("Всего за {$price$}, а доставка - бесплатно!")) - print(istext(b'\xe4\xb8\xad\xe6\x96\x87')) - print(istext( - '{"web_uri":"ad-site-i18n-sg/202103185d0d723d88b7f642452dac73","height":336,"width":336,"file_name":""}')) + print(istext(b"\xe4\xb8\xad\xe6\x96\x87")) + print( + istext( + '{"web_uri":"ad-site-i18n-sg/202103185d0d723d88b7f642452dac73","height":336,"width":336,"file_name":""}' + ) + ) diff --git a/httprunner/thrift/thrift_client.py b/httprunner/thrift/thrift_client.py index 72b5c98d..59cd2d5f 100644 --- a/httprunner/thrift/thrift_client.py +++ b/httprunner/thrift/thrift_client.py @@ -5,11 +5,19 @@ import json from loguru import logger import thriftpy2 -from thriftpy2.protocol import (TBinaryProtocolFactory, TCompactProtocolFactory, TCyBinaryProtocolFactory, - TJSONProtocolFactory) +from thriftpy2.protocol import ( + TBinaryProtocolFactory, + TCompactProtocolFactory, + TCyBinaryProtocolFactory, + TJSONProtocolFactory, +) from thriftpy2.rpc import make_client -from thriftpy2.transport import (TBufferedTransportFactory, TCyBufferedTransportFactory, TCyFramedTransportFactory, - TFramedTransportFactory) +from thriftpy2.transport import ( + TBufferedTransportFactory, + TCyBufferedTransportFactory, + TCyFramedTransportFactory, + TFramedTransportFactory, +) from thriftpy2.utils import deserialize from httprunner.thrift.data_convertor import json2thrift, thrift2json, thrift2dict @@ -57,9 +65,17 @@ def get_trans_factory(trans_type): class ThriftClient(object): - - def __init__(self, thrift_file, service_name, ip, port, include_dirs=None, timeout=3000, proto_type=ProtoType.pCyBinary, - trans_type=TransType.tCyBuffered): + def __init__( + self, + thrift_file, + service_name, + ip, + port, + include_dirs=None, + timeout=3000, + proto_type=ProtoType.pCyBinary, + trans_type=TransType.tCyBuffered, + ): self.thrift_file = thrift_file self.include_dirs = include_dirs self.service_name = service_name @@ -69,32 +85,54 @@ class ThriftClient(object): self.proto_type = proto_type self.trans_type = trans_type try: - logger.debug('init thrift module: thrift_file=%s, module_name=%s', thrift_file, - str(self.service_name) + '_thrift') - self.thrift_module = thriftpy2.load(self.thrift_file, module_name=str(self.service_name) + '_thrift', - include_dirs=self.include_dirs) + logger.debug( + "init thrift module: thrift_file=%s, module_name=%s", + thrift_file, + str(self.service_name) + "_thrift", + ) + self.thrift_module = thriftpy2.load( + self.thrift_file, + module_name=str(self.service_name) + "_thrift", + include_dirs=self.include_dirs, + ) self.thrift_service_obj = getattr(self.thrift_module, self.service_name) - logger.debug('init thrift client: service_name=%s, ip=%s, port=%s', self.thrift_service_obj, ip, port) - self.client = make_client(self.thrift_service_obj, self.ip, int(self.port), timeout=self.timeout, - proto_factory=get_proto_factory(self.proto_type), - trans_factory=get_trans_factory(self.trans_type)) + logger.debug( + "init thrift client: service_name=%s, ip=%s, port=%s", + self.thrift_service_obj, + ip, + port, + ) + self.client = make_client( + self.thrift_service_obj, + self.ip, + int(self.port), + timeout=self.timeout, + proto_factory=get_proto_factory(self.proto_type), + trans_factory=get_trans_factory(self.trans_type), + ) except Exception as e: self.thrift_module = None self.thrift_service_obj = None self.client = None - logger.exception('init thrift module and client failed: {}'.format(e)) + logger.exception("init thrift module and client failed: {}".format(e)) finally: thriftpy2.parser.parser.thrift_stack = [] def get_client(self): return self.client - def send_request(self, request_data, request_method=''): - thrift_req_cls = getattr(self.thrift_service_obj, request_method + '_args').thrift_spec[1][2] + def send_request(self, request_data, request_method=""): + thrift_req_cls = getattr( + self.thrift_service_obj, request_method + "_args" + ).thrift_spec[1][2] request_obj = json2thrift(json.dumps(request_data), thrift_req_cls) - logger.debug('send thrift request: request_method=%s, request_obj=%s', request_method, request_obj) + logger.debug( + "send thrift request: request_method=%s, request_obj=%s", + request_method, + request_obj, + ) response_obj = getattr(self.client, request_method)(request_obj) - logger.debug('thrift response = %s', response_obj) + logger.debug("thrift response = %s", response_obj) return thrift2dict(response_obj) def __del__(self):