From d1a8835b1e21dfbbe95cf89ff960b953a47b75cb Mon Sep 17 00:00:00 2001 From: "duanchao.bill" Date: Tue, 26 Apr 2022 17:59:50 +0800 Subject: [PATCH] add:sql and thrift as step --- httprunner/step_thrift_request.py | 225 +++++++++++++++ httprunner/thrift/data_convertor.py | 410 ++++++++++++++++++++++++++++ httprunner/thrift/thrift_client.py | 101 +++++++ 3 files changed, 736 insertions(+) create mode 100644 httprunner/step_thrift_request.py create mode 100644 httprunner/thrift/data_convertor.py create mode 100644 httprunner/thrift/thrift_client.py diff --git a/httprunner/step_thrift_request.py b/httprunner/step_thrift_request.py new file mode 100644 index 00000000..e3d69da8 --- /dev/null +++ b/httprunner/step_thrift_request.py @@ -0,0 +1,225 @@ +# -*- coding: utf-8 -*- +import time +from typing import Text, Union +from loguru import logger + +from httprunner import utils +from httprunner.exceptions import ValidationFailure +from httprunner.models import IStep, StepResult, TStep, ProtoType, TransType +from httprunner.runner import HttpRunner +from httprunner.step_request import call_hooks, StepRequestExtraction, StepRequestValidation +from httprunner.models import TThriftRequest +from httprunner.response import ThriftResponseObject + +from httprunner.thrift.thrift_client import ThriftClient + + +def run_step_thrift_request(runner: HttpRunner, step: TStep) -> StepResult: + """run teststep:thrift request""" + start_time = time.time() + + step_result = StepResult( + name=step.name, + success=False, + ) + step.variables = runner.merge_step_variables(step.variables) + # parse + request_dict = step.thrift_request.dict() + parsed_request_dict = runner.parser.parse_data( + request_dict, step.variables + ) + config = runner.get_config() + parsed_request_dict["psm"] = parsed_request_dict["psm"] or config.thrift.psm + parsed_request_dict["env"] = parsed_request_dict["env"] or config.thrift.env + parsed_request_dict["cluster"] = parsed_request_dict["cluster"] or config.thrift.cluster + parsed_request_dict["idl_path"] = parsed_request_dict["idl_path"] or config.thrift.idl_path + parsed_request_dict["include_dirs"] = parsed_request_dict["include_dirs"] or config.thrift.include_dirs + parsed_request_dict["method"] = parsed_request_dict["method"] or config.thrift.method + parsed_request_dict["service_name"] = parsed_request_dict["service_name"] or config.thrift.service_name + parsed_request_dict["ip"] = parsed_request_dict["ip"] or config.thrift.ip + parsed_request_dict["port"] = parsed_request_dict["port"] or config.thrift.port + parsed_request_dict["proto_type"] = parsed_request_dict["proto_type"] or config.thrift.proto_type + parsed_request_dict["trans_port"] = parsed_request_dict["trans_type"] or config.thrift.trans_type + parsed_request_dict["timeout"] = parsed_request_dict["timeout"] or config.thrift.timeout + parsed_request_dict["thrift_client"] = parsed_request_dict["thrift_client"] + + # parsed_request_dict["headers"].setdefault( + # "HRUN-Request-ID", + # f"HRUN-{self.__case_id}-{str(int(time.time() * 1000))[-6:]}", + # ) + step.variables["thrift_request"] = parsed_request_dict + + psm = parsed_request_dict["psm"] + + runner.thrift_client = parsed_request_dict["thrift_client"] + if not runner.thrift_client: + runner.thrift_client = ThriftClient(parsed_request_dict["idl_path"], parsed_request_dict["service_name"], + parsed_request_dict["ip"], parsed_request_dict["port"], + parsed_request_dict["timeout"], parsed_request_dict["proto_type"], + parsed_request_dict["trans_port"]) + + # setup hooks + if step.setup_hooks: + call_hooks(runner, step.setup_hooks, step.variables, "setup request") + + # thrift request + resp = runner.thrift_client.send_request(parsed_request_dict["params"], parsed_request_dict["method"]) + resp_obj = ThriftResponseObject(resp, parser=runner.parser) + step.variables["thrift_response"] = resp_obj + + # teardown hooks + if step.teardown_hooks: + call_hooks(runner, step.teardown_hooks, step.variables, "teardown request") + + def log_thrift_req_resp_details(): + err_msg = "\n{} THRIFT DETAILED REQUEST & RESPONSE {}\n".format("*" * 32, "*" * 32) + + # log request + err_msg += "====== thrift request details ======\n" + err_msg += f"psm: {psm}\n" + for k, v in parsed_request_dict.items(): + v = utils.omit_long_data(v) + err_msg += f"{k}: {repr(v)}\n" + + err_msg += "\n" + + # log response + err_msg += "====== thrift response details ======\n" + for k, v in resp.items(): + v = utils.omit_long_data(v) + err_msg += f"{k}: {repr(v)}\n" + logger.error(err_msg) + + # extract + extractors = step.extract + extract_mapping = resp_obj.extract(extractors) + step_result.export_vars = extract_mapping + + variables_mapping = step.variables + variables_mapping.update(extract_mapping) + + # validate + validators = step.validators + try: + resp_obj.validate( + validators, variables_mapping + ) + step_result.success = True + except ValidationFailure: + log_thrift_req_resp_details() + raise + finally: + session_data = runner.session.data + session_data.success = step_result.success + session_data.validators = resp_obj.validation_results + # save step data + step_result.data = session_data + step_result.elapsed = time.time() - start_time + return step_result + + +class StepThriftRequestValidation(StepRequestValidation): + def __init__(self, step: TStep): + self.__step = step + super().__init__(step) + + def run(self, runner: HttpRunner): + return run_step_thrift_request(runner, self.__step) + + +class StepThriftRequestExtraction(StepRequestExtraction): + def __init__(self, step: TStep): + self.__step = step + super().__init__(step) + + def run(self, runner: HttpRunner): + return run_step_thrift_request(runner, self.__step) + + def validate(self) -> StepThriftRequestValidation: + return StepThriftRequestValidation(self.__step) + + +class RunThriftRequest(IStep): + def __init__(self, name: Text): + self.__step = TStep(name=name) + self.__step.thrift_request = TThriftRequest() + + def with_variables(self, **variables) -> "RunThriftRequest": + self.__step.variables.update(variables) + return self + + def with_retry(self, retry_times, retry_interval) -> "RunThriftRequest": + self.__step.retry_times = retry_times + self.__step.retry_interval = retry_interval + return self + + def teardown_hook(self, hook: Text, assign_var_name: Text = None) -> "RunThriftRequest": + if assign_var_name: + self.__step.teardown_hooks.append({assign_var_name: hook}) + else: + self.__step.teardown_hooks.append(hook) + + return self + + def setup_hook(self, hook: Text, assign_var_name: Text = None) -> "RunTestCase": + if assign_var_name: + self.__step.setup_hooks.append({assign_var_name: hook}) + else: + self.__step.setup_hooks.append(hook) + + return self + + def with_params(self, **params) -> "RunThriftRequest": + self.__step.thrift_request.params.update(params) + return self + + def with_method(self, method) -> "RunThriftRequest": + self.__step.thrift_request.method = method + return self + + def with_idl_path(self, idl_path, idl_root_path) -> "RunThriftRequest": + self.__step.thrift_request.idl_path = idl_path + self.__step.thrift_request.include_dirs = [idl_root_path] + return self + + def with_thrift_client(self, thrift_client: Union["ThriftClient", str]) -> "RunThriftRequest": + self.__step.thrift_request.thrift_client = thrift_client + return self + + def with_ip(self,ip: str) -> "RunThriftRequest": + self.__step.thrift_request.ip = ip + return self + + def with_port(self, port: int) -> "RunThriftRequest": + self.__step.thrift_request.port = port + return self + + def with_proto_type(self,proto_type:ProtoType) -> "RunThriftRequest": + self.__step.thrift_request.proto_type = proto_type + return self + + def with_trans_type(self,trans_type:TransType) -> "RunThriftRequest": + self.__step.thrift_request.proto_type = trans_type + return self + + def struct(self) -> TStep: + return self.__step + + def name(self) -> Text: + return self.__step.name + + def type(self) -> Text: + return f"thrift-request-{self.__step.thrift_request.psm}-{self.__step.thrift_request.method}" + + def run(self, runner) -> StepResult: + return run_step_thrift_request(runner, self.__step) + + def extract(self) -> StepThriftRequestExtraction: + return StepThriftRequestExtraction(self.__step) + + def validate(self) -> StepThriftRequestValidation: + return StepThriftRequestValidation(self.__step) + + def with_jmespath(self, jmes_path: Text, var_name: Text) -> "StepThriftRequestExtraction": + self.__step.extract[var_name] = jmes_path + return StepThriftRequestExtraction(self.__step) diff --git a/httprunner/thrift/data_convertor.py b/httprunner/thrift/data_convertor.py new file mode 100644 index 00000000..f54ab3cb --- /dev/null +++ b/httprunner/thrift/data_convertor.py @@ -0,0 +1,410 @@ +# -*- coding: utf-8 -*- + +from __future__ import division + +import json +import traceback +import re +import logging +import base64 + +from thrift.Thrift import TType + +try: + from _json import encode_basestring_ascii as c_encode_basestring_ascii +except ImportError: + c_encode_basestring_ascii = None + +text_characters = "".join(map(chr, range(32, 127))) + "\n\r\t\b" +_null_trans = str.maketrans("", "") +ESCAPE = re.compile(r'[\x00-\x1f\\"\b\f\n\r\t]') +ESCAPE_ASCII = re.compile(r'([\\"]|[^\ -~])') +HAS_UTF8 = re.compile(r'[\x80-\xff]') +ESCAPE_DCT = { + '\\': '\\\\', + '"': '\\"', + '\b': '\\b', + '\f': '\\f', + '\n': '\\n', + '\r': '\\r', + '\t': '\\t', +} +for i in range(0x20): + ESCAPE_DCT.setdefault(chr(i), '\\u{0:04x}'.format(i)) + # ESCAPE_DCT.setdefault(chr(i), '\\u%04x' % (i,)) + +INFINITY = float('inf') +FLOAT_REPR = repr + + +def istext(s_input): + """ + 既然我们要判断这串内容是不是可以做为Json的value,那为什么不放下试试呢? + :param s_input: + :return: + """ + return not isinstance(s_input, bytes) + + +def unicode_2_utf8_keep_native(para): + # if type(para) is str: + # return ''.join(filter(lambda x: not str.isalpha(x), para)) + if type(para) is str: + return para + + if type(para) is list: + for i in range(len(para)): + para[i] = unicode_2_utf8_keep_native(para[i]) + return para + elif type(para) is dict: + newpara = {} + for (key, value) in para.items(): + key = unicode_2_utf8_keep_native(key) + value = unicode_2_utf8_keep_native(value) + newpara[key] = value + return newpara + elif type(para) is tuple: + return tuple(unicode_2_utf8_keep_native(list(para))) + elif type(para) is str: + return para.encode('utf-8') + else: + logging.debug("type========", type(para)) + # if issubclass(type(para), dict): + if isinstance(para, dict): + logging.debug("type ************in dict: %s" % (type(para))) + return unicode_2_utf8_keep_native(dict(para)) + else: + return para + + +def encode_basestring(s): + """Return a JSON representation of a Python string + + """ + + def replace(match): + return ESCAPE_DCT[match.group(0)] + + return '"' + ESCAPE.sub(replace, s) + '"' + + +def py_encode_basestring_ascii(s): + """Return an ASCII-only JSON representation of a Python string + + """ + if isinstance(s, str) and HAS_UTF8.search(s) is not None: + s = s.decode('utf-8') + + def replace(match): + s = match.group(0) + try: + return ESCAPE_DCT[s] + except KeyError: + n = ord(s) + if n < 0x10000: + return '\\u{0:04x}'.format(n) + # return '\\u%04x' % (n,) + else: + # surrogate pair + n -= 0x10000 + s1 = 0xd800 | ((n >> 10) & 0x3ff) + s2 = 0xdc00 | (n & 0x3ff) + return '\\u{0:04x}\\u{1:04x}'.format(s1, s2) + # return '\\u%04x\\u%04x' % (s1, s2) + + return '"' + str(ESCAPE_ASCII.sub(replace, s)) + '"' + + +encode_basestring_ascii = ( + c_encode_basestring_ascii or py_encode_basestring_ascii) + + +class ThriftJSONDecoder(json.JSONDecoder): + + def __init__(self, *args, **kwargs): + self._thrift_class = kwargs.pop('thrift_class') + super(ThriftJSONDecoder, self).__init__(*args, **kwargs) + + def decode(self, json_str): + if isinstance(json_str, dict): + dct = json_str + else: + dct = super(ThriftJSONDecoder, self).decode(json_str) + return self._convert(dct, TType.STRUCT, + # (self._thrift_class, self._thrift_class.thrift_spec)) + self._thrift_class) + + def _convert(self, val, ttype, ttype_info): + if ttype == TType.STRUCT: + if val is None: + ret = None + else: + # (thrift_class, thrift_spec) = ttype_info + thrift_class = ttype_info + thrift_spec = ttype_info.thrift_spec + ret = thrift_class() + for tag, field in thrift_spec.items(): + if field is None: + continue + # {1: (15, 'ad_ids', 10, False), 255: (12, 'Base', , False)} + # {1: (15, 'models', (12, ), False), 255: (12, 'BaseResp', , False)} + if len(field) <= 3: + (field_ttype, field_name, dummy) = field + field_ttype_info = None + else: + (field_ttype, field_name, field_ttype_info, dummy) = field + + if val is None or field_name not in val: + continue + converted_val = self._convert(val[field_name], field_ttype, field_ttype_info) + setattr(ret, field_name, converted_val) + elif ttype == TType.LIST: + if type(ttype_info) != tuple: # 说明是基础类型了, 无法在细分 + (element_ttype, element_ttype_info) = (ttype_info, None) + else: + (element_ttype, element_ttype_info) = ttype_info + if val is not None: + ret = [self._convert(x, element_ttype, element_ttype_info) for x in val] + else: + ret = None + + elif ttype == TType.SET: + if type(ttype_info) != tuple: # 说明是基础类型了, 无法在细分 + (element_ttype, element_ttype_info) = (ttype_info, None) + else: + (element_ttype, element_ttype_info) = ttype_info + if val is not None: + ret = set([self._convert(x, element_ttype, element_ttype_info) for x in val]) + else: + ret = None + + elif ttype == TType.MAP: + # key处理 + if type(ttype_info[0]) == tuple: + key_ttype, key_ttype_info = ttype_info[0] + else: + key_ttype, key_ttype_info = ttype_info[0], None + + # value处理 + if type(ttype_info[1]) != tuple: # 说明value为基础类型, 已不可在细分 + val_ttype = ttype_info[1] + val_ttype_info = None + else: + val_ttype, val_ttype_info = ttype_info[1] + + if val is not None: + ret = dict([(self._convert(k, key_ttype, key_ttype_info), + self._convert(v, val_ttype, val_ttype_info)) for (k, v) in val.items()]) + else: + ret = None + elif ttype == TType.STRING: + if isinstance(val, str): + ret = val.encode("utf8") + elif val is None: + ret = None + else: + ret = str(val) + # 判断string字段是否是base64编码后的string, 如果是则此处需要对该string字段进行b64decode, 还原成原本的字符串 + # todo : 留待实现 + + elif ttype == TType.DOUBLE: + if val is not None: + ret = float(val) + else: + ret = None + elif ttype == TType.I64: + if val is not None: + ret = int(val) + else: + ret = None + elif ttype == TType.I32 or ttype == TType.I16 or ttype == TType.BYTE: + if val is not None: + ret = int(val) + else: + ret = None + elif ttype == TType.BOOL: + if val is not None: + ret = bool(val) + else: + ret = None + else: + raise TypeError('Unrecognized thrift field type: %s' % ttype) + return ret + + +def json2thrift(json_str, thrift_class): + logging.debug(json_str) + return json.loads(json_str, cls=ThriftJSONDecoder, thrift_class=thrift_class, strict=False) + + +def dumper(obj): + try: + return json.dumps(obj, default=lambda o: o.__dict__, sort_keys=True, indent=2) + except: + return obj.__dict__ + + +class MyJSONEncoder(json.JSONEncoder): + def __init__(self, skipkeys=False, ensure_ascii=True, check_circular=True, + allow_nan=True, indent=None, separators=None, + encoding='utf-8', default=None, sort_keys=False, **kw): + super(MyJSONEncoder, self).__init__(skipkeys=skipkeys, ensure_ascii=ensure_ascii, + check_circular=check_circular, allow_nan=allow_nan, indent=indent, + separators=separators, encoding=encoding, default=default, + sort_keys=sort_keys) + self.skip_nonutf8_value = kw.get('skip_nonutf8_value', False) # 默认不skip忽略非utf-8编码的字段 + + def encode(self, o): + """Return a JSON string representation of a Python data structure. + JSONEncoder().encode({"foo": ["bar", "baz"]}) + '{"foo": ["bar", "baz"]}' + + """ + # This is for extremely simple cases and benchmarks. + + if isinstance(o, str): + + if isinstance(o, str): + _encoding = self.encoding + if (_encoding is not None + and not (_encoding == 'utf-8')): + o = o.decode(_encoding) + if self.ensure_ascii: + return encode_basestring_ascii(o) + else: + return encode_basestring(o) + # This doesn't pass the iterator directly to ''.join() because the + # exceptions aren't as detailed. The list call should be roughly + # equivalent to the PySequence_Fast that ''.join() would do. + chunks = self.iterencode(o, _one_shot=True) + if not isinstance(chunks, (list, tuple)): + chunks = list(chunks) + # add by braver(braver@bytedance.com) + # todo: fix 'utf8' codec can't decode byte 0x91 in position 3: invalid start byte" + if self.skip_nonutf8_value: # 缺省为false + tmp_chunks = [] + for chunk in chunks: + try: + tmp_chunks.append(unicode_2_utf8_keep_native(chunk)) + except Exception as err: + logging.debug(traceback.format_exc()) + return ''.join(tmp_chunks) + + # 保留老的逻辑, /usr/lib/python2.7/package/json/__init__.py dumps接口 + return ''.join(chunks) + + +class ThriftJSONEncoder(json.JSONEncoder): + """ + add by braver(Braver@bytedance.com) + """ + + def __init__(self, skipkeys=False, ensure_ascii=True, check_circular=True, + allow_nan=True, indent=None, separators=None, default=None, sort_keys=False, **kw): + + super(ThriftJSONEncoder, self).__init__(skipkeys=skipkeys, ensure_ascii=ensure_ascii, + check_circular=check_circular, allow_nan=allow_nan, indent=indent, + separators=separators, default=default, sort_keys=sort_keys) + self.skip_nonutf8_value = kw.get('skip_nonutf8_value', False) # 默认不skip忽略非utf-8编码的字段 + + def encode(self, o): + """Return a JSON string representation of a Python data structure. + JSONEncoder().encode({"foo": ["bar", "baz"]}) + '{"foo": ["bar", "baz"]}' + + """ + # This is for extremely simple cases and benchmarks. + + if isinstance(o, str): + if isinstance(o, str): + _encoding = self.encoding + if (_encoding is not None + and not (_encoding == 'utf-8')): + o = o.decode(_encoding) + if self.ensure_ascii: + return encode_basestring_ascii(o) + else: + return encode_basestring(o) + # This doesn't pass the iterator directly to ''.join() because the + # exceptions aren't as detailed. The list call should be roughly + # equivalent to the PySequence_Fast that ''.join() would do. + chunks = self.iterencode(o, _one_shot=True) + if not isinstance(chunks, (list, tuple)): + chunks = list(chunks) + # add by braver(braver@bytedance.com) + # todo: fix 'utf8' codec can't decode byte 0x91 in position 3: invalid start byte" + if self.skip_nonutf8_value: # 缺省为false + tmp_chunks = [] + for chunk in chunks: + try: + tmp_chunks.append(unicode_2_utf8_keep_native(chunk)) + except Exception as err: + logging.debug(traceback.format_exc()) + return ''.join(tmp_chunks) + + # 保留老的逻辑, /usr/lib/python2.7/package/json/__init__.py dumps接口 + return ''.join(chunks) + + def default(self, o): + if isinstance(o, bytes): + return str(o, encoding='utf-8') + if not hasattr(o, 'thrift_spec'): + return super(ThriftJSONEncoder, self).default(o) + + spec = getattr(o, 'thrift_spec') + ret = {} + for tag, field in spec.items(): + if field is None: + continue + # (tag, field_ttype, field_name, field_ttype_info, default) = field + field_name = field[1] + default = field[-1] + field_type = field[0] + field_ttype_info = field[2] + # if field_type in [TType.STRING, TType.BINARY]: # 说明是string(明文string或者binary) + # if field_type in [TType.STRING, TType.BYTE]: # 说明是string(明文string或者binary) + if field_name in o.__dict__: + val = o.__dict__[field_name] + if field_type in [TType.LIST, TType.SET]: # 数组类型 + if val: # val为非空数组/Set + val = list(val) # 统一转成数组(list/set) + is_need_binary_bs64 = False + if type(field_ttype_info) != tuple: # 基础类型 + if field_ttype_info in [TType.BYTE] and type(val[0]) in [str] and not istext( + val[0]): + is_need_binary_bs64 = True + if is_need_binary_bs64: + for index, item in enumerate(val): + if item and type(item) in [str] and not istext(item): + val[index] = base64.b64encode(item) # 判断为二进制字符串, 需要进行base64编码 + if field_type in [TType.BYTE] and type(val) in [str]: # 说明是string(明文string或者binary) + # 需要对二进制字节字符串字段进行base64编码, 将二进制字节串字段->ascii字符编码的base64编码明文串 + if val and not istext(val): # 说明是该字段非空且为binary string + print('4' * 100, val) + val = base64.b64encode(val.encode('utf-8')) + # val = base64.b64encode(val) # 进行base64编码处理, 不然该字段序列化为json时会报错 + # if val != default: + ret[field_name] = val + if 'request_id' in o.__dict__: + ret['request_id'] = o.__dict__['request_id'] + if 'rpc_latency' in o.__dict__: + ret['rpc_latency'] = o.__dict__['rpc_latency'] + return ret + + +def thrift2json(obj, skip_nonutf8_value=False): + return json.dumps(obj, cls=ThriftJSONEncoder, ensure_ascii=False, skip_nonutf8_value=skip_nonutf8_value) + + +def thrift2dict(obj): + str = thrift2json(obj) + return json.loads(str) + + +dict2thrift = json2thrift + +if __name__ == '__main__': + print(istext("Всего за {$price$}, а доставка - бесплатно!")) + print(istext(b'\xe4\xb8\xad\xe6\x96\x87')) + print(istext( + '{"web_uri":"ad-site-i18n-sg/202103185d0d723d88b7f642452dac73","height":336,"width":336,"file_name":""}')) diff --git a/httprunner/thrift/thrift_client.py b/httprunner/thrift/thrift_client.py new file mode 100644 index 00000000..72b5c98d --- /dev/null +++ b/httprunner/thrift/thrift_client.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import +import enum +import json + +from loguru import logger +import thriftpy2 +from thriftpy2.protocol import (TBinaryProtocolFactory, TCompactProtocolFactory, TCyBinaryProtocolFactory, + TJSONProtocolFactory) +from thriftpy2.rpc import make_client +from thriftpy2.transport import (TBufferedTransportFactory, TCyBufferedTransportFactory, TCyFramedTransportFactory, + TFramedTransportFactory) +from thriftpy2.utils import deserialize + +from httprunner.thrift.data_convertor import json2thrift, thrift2json, thrift2dict + + +class ProtoType(enum.Enum): + pBinary = 1 + pCyBinary = 2 + pCompact = 3 + pJson = 4 + + +class TransType(enum.Enum): + tBuffered = 1 + tCyBuffered = 2 + tFramed = 3 + tCyFramed = 4 + + +class RequestFormat(enum.Enum): + json = 1 + binary = 2 + + +def get_proto_factory(proto_type): + if proto_type == ProtoType.pBinary: + return TBinaryProtocolFactory() + if proto_type == ProtoType.pCyBinary: + return TCyBinaryProtocolFactory() + if proto_type == ProtoType.pCompact: + return TCompactProtocolFactory() + if proto_type == ProtoType.pJson: + return TJSONProtocolFactory() + + +def get_trans_factory(trans_type): + if trans_type == TransType.tBuffered: + return TBufferedTransportFactory() + if trans_type == TransType.tCyBuffered: + return TCyBufferedTransportFactory() + if trans_type == TransType.tFramed: + return TFramedTransportFactory() + if trans_type == TransType.tCyFramed: + return TCyFramedTransportFactory() + + +class ThriftClient(object): + + def __init__(self, thrift_file, service_name, ip, port, include_dirs=None, timeout=3000, proto_type=ProtoType.pCyBinary, + trans_type=TransType.tCyBuffered): + self.thrift_file = thrift_file + self.include_dirs = include_dirs + self.service_name = service_name + self.ip = ip + self.port = port + self.timeout = timeout + self.proto_type = proto_type + self.trans_type = trans_type + try: + logger.debug('init thrift module: thrift_file=%s, module_name=%s', thrift_file, + str(self.service_name) + '_thrift') + self.thrift_module = thriftpy2.load(self.thrift_file, module_name=str(self.service_name) + '_thrift', + include_dirs=self.include_dirs) + self.thrift_service_obj = getattr(self.thrift_module, self.service_name) + logger.debug('init thrift client: service_name=%s, ip=%s, port=%s', self.thrift_service_obj, ip, port) + self.client = make_client(self.thrift_service_obj, self.ip, int(self.port), timeout=self.timeout, + proto_factory=get_proto_factory(self.proto_type), + trans_factory=get_trans_factory(self.trans_type)) + except Exception as e: + self.thrift_module = None + self.thrift_service_obj = None + self.client = None + logger.exception('init thrift module and client failed: {}'.format(e)) + finally: + thriftpy2.parser.parser.thrift_stack = [] + + def get_client(self): + return self.client + + def send_request(self, request_data, request_method=''): + thrift_req_cls = getattr(self.thrift_service_obj, request_method + '_args').thrift_spec[1][2] + request_obj = json2thrift(json.dumps(request_data), thrift_req_cls) + logger.debug('send thrift request: request_method=%s, request_obj=%s', request_method, request_obj) + response_obj = getattr(self.client, request_method)(request_obj) + logger.debug('thrift response = %s', response_obj) + return thrift2dict(response_obj) + + def __del__(self): + self.client.close()