add:sql and thrift as step

This commit is contained in:
duanchao.bill
2022-04-26 17:59:20 +08:00
parent 9a0ffa9802
commit 4b9433fa72
9 changed files with 575 additions and 65 deletions

View File

@@ -7,6 +7,8 @@ from httprunner.runner import HttpRunner
from httprunner.step import Step
from httprunner.step_request import RunRequest
from httprunner.step_testcase import RunTestCase
from httprunner.step_sql_request import RunSqlRequest, StepSqlRequestValidation, StepSqlRequestExtraction
from httprunner.step_thrift_request import RunThriftRequest, StepThriftRequestValidation, StepThriftRequestExtraction
__all__ = [
"__version__",
@@ -15,6 +17,12 @@ __all__ = [
"Config",
"Step",
"RunRequest",
"RunSqlRequest",
"StepSqlRequestValidation",
"StepSqlRequestExtraction",
"RunThriftRequest",
"StepThriftRequestValidation",
"StepThriftRequestExtraction",
"RunTestCase",
"Parameters",
]

View File

@@ -1,7 +1,7 @@
import inspect
from typing import Text
from httprunner.models import TConfig, TConfigThrift
from httprunner.models import TConfig, TConfigThrift, TConfigDB, ProtoType
class ConfigThrift(object):
@@ -21,8 +21,65 @@ class ConfigThrift(object):
self.__config.thrift.cluster = cluster
return self
def target(self, target: Text) -> "ConfigThrift":
self.__config.thrift.target = target
def service_name(self, service_name: Text) -> "ConfigThrift":
self.__config.thrift.service_name = service_name
return self
def method(self, method: Text) -> "ConfigThrift":
self.__config.thrift.method = method
return self
def ip(self, service_name_: Text) -> "ConfigThrift":
self.__config.thrift.service_name = service_name_
return self
def port(self, port: int) -> "ConfigThrift":
self.__config.thrift.port = port
return self
def timeout(self, timeout: int) -> "ConfigThrift":
self.__config.thrift.timeout = timeout
return self
def proto_type(self, proto_type: ProtoType) -> "ConfigThrift":
self.__config.thrift.proto_type = proto_type
return self
def trans_type(self, trans_type: ProtoType) -> "ConfigThrift":
self.__config.thrift.trans_type = trans_type
return self
def struct(self) -> TConfig:
return self.__config
class ConfigDB(object):
def __init__(self, config: TConfig):
self.__config = config
self.__config.db = TConfigDB()
def psm(self, psm):
self.__config.db.psm = psm
return self
def user(self, user):
self.__config.db.user = user
return self
def password(self, password):
self.__config.db.password = password
return self
def ip(self, ip):
self.__config.db.ip = ip
return self
def port(self, port: int):
self.__config.db.port = port
return self
def database(self, database: Text):
self.__config.db.database = database
return self
def struct(self) -> TConfig:
@@ -64,3 +121,6 @@ class Config(object):
def thrift(self) -> ConfigThrift:
return ConfigThrift(self.__config)
def db(self) -> ConfigDB:
return ConfigDB(self.__config)

View File

@@ -0,0 +1,79 @@
# -*- coding: utf-8 -*-
import datetime
import json
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
class DBEngine(object):
def __init__(self, db_uri):
"""
db_uri = f'mysql+pymysql://{username}:{password}@{host}:{port}/{database}?charset=utf8mb4'
"""
engine = create_engine(db_uri)
self.session = sessionmaker(bind=engine)()
@staticmethod
def value_decode(row: dict):
"""
Try to decode value of table
datetime.datetime-->string
datetime.date-->string
json str-->dict
:param row:
:return:
"""
for k, v in row.items():
if isinstance(v, datetime.datetime):
row[k] = v.strftime('%Y-%m-%d %H:%M:%S')
elif isinstance(v, datetime.date):
row[k] = v.strftime("%Y-%m-%d")
elif isinstance(v, str):
try:
row[k] = json.loads(v)
except ValueError:
pass
def _fetch(self, query, size=-1, commit=True):
result = self.session.execute(query)
self.session.commit() if commit else 0
if query.upper()[:6] == "SELECT":
if size < 0:
al = result.fetchall()
al = [dict(el) for el in al]
return al or None
elif size == 1:
on = dict(result.fetchone())
self.value_decode(on)
return on or None
else:
mny = result.fetchmany(size)
mny = [dict(el) for el in mny]
return mny or None
elif query.upper()[:6] in ("UPDATE", "DELETE", "INSERT"):
return {"rowcount": result.rowcount}
def fetchone(self, query, commit=True):
return self._fetch(query, size=1, commit=commit)
def fetchmany(self, query, size, commit=True):
return self._fetch(query=query, size=size, commit=commit)
def fetchall(self, query, commit=True):
return self._fetch(query=query, size=-1, commit=commit)
def insert(self, query, commit=True):
return self._fetch(query=query, commit=commit)
def delete(self, query, commit=True):
return self._fetch(query=query, commit=commit)
def update(self, query, commit=True):
return self._fetch(query=query, commit=commit)
if __name__ == '__main__':
db = DBEngine(
f"mysql+pymysql://xxxxx:xxxxx@10.0.0.1:3306/dbname?charset=utf8mb4")

View File

@@ -86,3 +86,7 @@ class TestcaseNotFound(NotFoundError):
class SummaryEmpty(MyBaseError):
"""test result summary data is empty"""
class SqlMethodNotSupport(MyBaseError):
pass

View File

@@ -28,6 +28,20 @@ class MethodEnum(Text, Enum):
PATCH = "PATCH"
class ProtoType(Enum):
pBinary = 1
pCyBinary = 2
pCompact = 3
pJson = 4
class TransType(Enum):
tBuffered = 1
tCyBuffered = 2
tFramed = 3
tCyFramed = 4
# configs for thrift rpc
class TConfigThrift(BaseModel):
psm: Text = None
@@ -36,6 +50,66 @@ class TConfigThrift(BaseModel):
target: Text = None
include_dirs: List[Text] = None
thrift_client: Any = None
timeout: int = 10
idl_path: Text = None
method: Text = None
ip: Text = "127.0.0.1"
port: int = 9000
service_name: Text = None
proto_type: ProtoType = ProtoType.pBinary
trans_type: TransType = TransType.tBuffered
# configs for db
class TConfigDB(BaseModel):
psm: Text = None
user: Text = None
password: Text = None
ip: Text = None
port: int = 3306
database: Text = None
class TransportEnum(Text, Enum):
BUFFERED = "buffered"
FRAMED = "framed"
class TThriftRequest(BaseModel):
""" rpc request model"""
method: Text = ''
params: Dict = {}
thrift_client: Any = None
idl_path: Text = '' # idl local path
timeout: int = 10 # sec
transport: TransportEnum = TransportEnum.BUFFERED
include_dirs: List[Union[Text, None]] = [] # param of thriftpy2.load
target: Text = "" # tcp://{ip}:{port} or sd://psm?cluster=xx&env=xx
env: Text = "prod"
cluster: Text = "default"
psm: Text = ""
service_name: Text = None
ip: Text = None
port: int = None
proto_type: ProtoType = None
trans_type: TransType = None
class SqlMethodEnum(Text, Enum):
FETCHONE = "FETCHONE"
FETCHMANY = "FETCHMANY"
FETCHALL = "FETCHALL"
INSERT = "INSERT"
UPDATE = "UPDATE"
DELETE = "DELETE"
class TSqlRequest(BaseModel):
""" sql request model"""
db_config: TConfigDB = TConfigDB()
method: SqlMethodEnum = None
sql: Text = None
size: int = 0 # limit nums of sql result
class TConfig(BaseModel):
@@ -51,6 +125,7 @@ class TConfig(BaseModel):
path: Text = None
# configs for other protocols
thrift: TConfigThrift = None
db: TConfigDB = TConfigDB()
class TRequest(BaseModel):
@@ -84,6 +159,8 @@ class TStep(BaseModel):
validate_script: List[Text] = []
retry_times: int = 0
retry_interval: int = 0 # sec
thrift_request: Union[TThriftRequest, None] = None
sql_request: Union[TSqlRequest, None] = None
class TestCase(BaseModel):

View File

@@ -1,18 +1,17 @@
from typing import Any, Dict, Text
from typing import Dict, Text, Any
import jmespath
import requests
from jmespath.exceptions import JMESPathError
from loguru import logger
from httprunner import exceptions
from httprunner.exceptions import ParamsError, ValidationFailure
from httprunner.models import Validators, VariablesMapping
from httprunner.parser import Parser, parse_string_value
from httprunner.exceptions import ValidationFailure, ParamsError
from httprunner.models import VariablesMapping, Validators
from httprunner.parser import parse_string_value, Parser
def get_uniform_comparator(comparator: Text):
"""convert comparator alias to uniform name"""
""" convert comparator alias to uniform name"""
if comparator in ["eq", "equals", "equal"]:
return "equal"
elif comparator in ["lt", "less_than"]:
@@ -113,9 +112,9 @@ def uniform_validator(validator):
}
class ResponseObject(object):
def __init__(self, resp_obj: requests.Response, parser: Parser):
"""initialize with a requests.Response object
class ResponseObjectBase(object):
def __init__(self, resp_obj, parser: Parser):
""" initialize with a response object
Args:
resp_obj (instance): requests.Response instance
@@ -125,71 +124,33 @@ class ResponseObject(object):
self.parser = parser
self.validation_results: Dict = {}
def __getattr__(self, key):
if key in ["json", "content", "body"]:
try:
value = self.resp_obj.json()
except ValueError:
value = self.resp_obj.content
elif key == "cookies":
value = self.resp_obj.cookies.get_dict()
else:
try:
value = getattr(self.resp_obj, key)
except AttributeError:
err_msg = "ResponseObject does not have attribute: {}".format(key)
logger.error(err_msg)
raise exceptions.ParamsError(err_msg)
self.__dict__[key] = value
return value
def _search_jmespath(self, expr: Text) -> Any:
resp_obj_meta = {
"status_code": self.status_code,
"headers": self.headers,
"cookies": self.cookies,
"body": self.body,
}
if not expr.startswith(tuple(resp_obj_meta.keys())):
return expr
try:
check_value = jmespath.search(expr, resp_obj_meta)
except JMESPathError as ex:
logger.error(
f"failed to search with jmespath\n"
f"expression: {expr}\n"
f"data: {resp_obj_meta}\n"
f"exception: {ex}"
)
raise
return check_value
def extract(
self,
extractors: Dict[Text, Text],
variables_mapping: VariablesMapping = None,
) -> Dict[Text, Any]:
def extract(self,
extractors: Dict[Text, Text],
variables_mapping: VariablesMapping = None,
) -> Dict[Text, Any]:
if not extractors:
return {}
extract_mapping = {}
for key, field in extractors.items():
if "$" in field:
if '$' in field:
# field contains variable or function
field = self.parser.parse_data(field, variables_mapping)
field = self.parser.parse_data(
field, variables_mapping
)
field_value = self._search_jmespath(field)
extract_mapping[key] = field_value
logger.info(f"extract mapping: {extract_mapping}")
return extract_mapping
def _search_jmespath(self, expr: Text) -> Any:
raise NotImplementedError("_search_jmespath not override")
def validate(
self,
validators: Validators,
variables_mapping: VariablesMapping = None,
self,
validators: Validators,
variables_mapping: VariablesMapping = None,
):
variables_mapping = variables_mapping or {}
@@ -212,7 +173,9 @@ class ResponseObject(object):
check_item = u_validator["check"]
if "$" in check_item:
# check_item is variable or function
check_item = self.parser.parse_data(check_item, variables_mapping)
check_item = self.parser.parse_data(
check_item, variables_mapping
)
check_item = parse_string_value(check_item)
if check_item and isinstance(check_item, Text):
@@ -274,3 +237,66 @@ class ResponseObject(object):
if not validate_pass:
failures_string = "\n".join([failure for failure in failures])
raise ValidationFailure(failures_string)
class ResponseObject(ResponseObjectBase):
def __getattr__(self, key):
if key in ["json", "content", "body"]:
try:
value = self.resp_obj.json()
except ValueError:
value = self.resp_obj.content
elif key == "cookies":
value = self.resp_obj.cookies.get_dict()
else:
try:
value = getattr(self.resp_obj, key)
except AttributeError:
err_msg = "ResponseObject does not have attribute: {}".format(key)
logger.error(err_msg)
raise exceptions.ParamsError(err_msg)
self.__dict__[key] = value
return value
def _search_jmespath(self, expr: Text) -> Any:
resp_obj_meta = {
"status_code": self.status_code,
"headers": self.headers,
"cookies": self.cookies,
"body": self.body,
}
if not expr.startswith(tuple(resp_obj_meta.keys())):
return expr
try:
check_value = jmespath.search(expr, resp_obj_meta)
except JMESPathError as ex:
logger.error(
f"failed to search with jmespath\n"
f"expression: {expr}\n"
f"data: {resp_obj_meta}\n"
f"exception: {ex}"
)
raise
return check_value
class ThriftResponseObject(ResponseObjectBase):
def _search_jmespath(self, expr: Text) -> Any:
try:
check_value = jmespath.search(expr, self.resp_obj)
except JMESPathError as ex:
logger.error(
f"failed to search with jmespath\n"
f"expression: {expr}\n"
f"data: {self.resp_obj}\n"
f"exception: {ex}"
)
raise
return check_value
class SqlResponseObject(ThriftResponseObject):
pass

View File

@@ -38,6 +38,8 @@ class SessionRunner(object):
session: HttpSession = None
case_id: Text = ""
root_dir: Text = ""
thrift_client = None
db_engine = None
__config: TConfig
__project_meta: ProjectMeta = None
@@ -87,6 +89,12 @@ class SessionRunner(object):
self.__export = export
return self
def with_thrift_client(self, thrift_client) -> "SessionRunner":
self.thrift_client = thrift_client
def with_db_engine(self,db_engine):
self.db_engine = db_engine
def __parse_config(self, param: Dict = None) -> None:
# parse config variables
self.__config.variables.update(self.__session_variables)

View File

@@ -8,6 +8,7 @@ from httprunner.step_request import (
StepRequestValidation,
)
from httprunner.step_testcase import StepRefCase
from httprunner.step_sql_request import RunSqlRequest, StepSqlRequestValidation, StepSqlRequestExtraction
class Step(object):
@@ -18,6 +19,9 @@ class Step(object):
StepRequestExtraction,
RequestWithOptionalArgs,
StepRefCase,
RunSqlRequest,
StepSqlRequestValidation,
StepSqlRequestExtraction,
],
):
self.__step = step

View File

@@ -0,0 +1,244 @@
# -*- coding: utf-8 -*-
import time
from typing import Text
from loguru import logger
from httprunner import utils
from httprunner.exceptions import ValidationFailure
from httprunner.models import IStep, StepResult, TStep
from httprunner.models import TSqlRequest, SqlMethodEnum
from httprunner.response import SqlResponseObject
from httprunner.runner import HttpRunner
from httprunner.step_request import call_hooks, StepRequestExtraction, StepRequestValidation
from httprunner.database.engine import DBEngine
from httprunner.exceptions import SqlMethodNotSupport
def run_step_sql_request(runner: HttpRunner, step: TStep) -> StepResult:
"""run teststep:sql request"""
start_time = time.time()
step_result = StepResult(
name=step.name,
success=False,
)
step.variables = runner.merge_step_variables(step.variables)
# parse
request_dict = step.sql_request.dict()
parsed_request_dict = runner.parser.parse_data(
request_dict, step.variables
)
config = runner.get_config()
parsed_request_dict["db_config"]["psm"] = parsed_request_dict["db_config"]["psm"] or config.db.psm
parsed_request_dict["db_config"]["user"] = parsed_request_dict["db_config"]["user"] or config.db.user
parsed_request_dict["db_config"]["password"] = parsed_request_dict["db_config"]["password"] or config.db.password
parsed_request_dict["db_config"]["ip"] = parsed_request_dict["db_config"]["ip"] or config.db.ip
parsed_request_dict["db_config"]["port"] = parsed_request_dict["db_config"]["port"] or config.db.port
parsed_request_dict["db_config"]["database"] = parsed_request_dict["db_config"]["database"] or config.db.database
if parsed_request_dict["db_config"]["psm"]:
runner.db_engine = DBEngine(f'mysql+pymysql://:@/?charset=utf8mb4&db_psm={parsed_request_dict["psm"]}')
else:
runner.db_engine = DBEngine(
f'mysql+pymysql://{parsed_request_dict["db_config"]["user"]}:'
f'{parsed_request_dict["db_config"]["password"]}@{parsed_request_dict["db_config"]["ip"]}:'
f'{parsed_request_dict["db_config"]["port"]}/{parsed_request_dict["db_config"]["database"]}'
f'?charset=utf8mb4')
# parsed_request_dict["headers"].setdefault(
# "HRUN-Request-ID",
# f"HRUN-{self.__case_id}-{str(int(time.time() * 1000))[-6:]}",
# )
# setup hooks
if step.setup_hooks:
call_hooks(runner, step.setup_hooks, step.variables, "setup request")
logger.info(f"Executing SQL: {parsed_request_dict['sql']}")
if step.sql_request.method == SqlMethodEnum.FETCHONE:
sql_resp = runner.db_engine.fetchone(parsed_request_dict['sql'])
elif step.sql_request.method == SqlMethodEnum.INSERT:
sql_resp = runner.db_engine.insert(parsed_request_dict['sql'])
elif step.sql_request.method == SqlMethodEnum.FETCHMANY:
sql_resp = runner.db_engine.fetchmany(parsed_request_dict['sql'], parsed_request_dict['size'])
elif step.sql_request.method == SqlMethodEnum.FETCHALL:
sql_resp = runner.db_engine.fetchall(parsed_request_dict['sql'])
elif step.sql_request.method == SqlMethodEnum.UPDATE:
sql_resp = runner.db_engine.update(parsed_request_dict['sql'])
elif step.sql_request.method == SqlMethodEnum.DELETE:
sql_resp = runner.db_engine.delete(parsed_request_dict['sql'])
else:
raise SqlMethodNotSupport(f"step.sql_request.method {parsed_request_dict['method']} not support")
resp_obj = SqlResponseObject(sql_resp, parser=runner.parser)
step.variables["sql_response"] = resp_obj
# teardown hooks
if step.teardown_hooks:
call_hooks(runner, step.teardown_hooks, step.variables, "teardown request")
def log_sql_req_resp_details():
err_msg = "\n{} SQL DETAILED REQUEST & RESPONSE {}\n".format("*" * 32, "*" * 32)
# log request
err_msg += "====== sql request details ======\n"
err_msg += f"sql: {step.sql_request.sql}\n"
for k, v in parsed_request_dict.items():
v = utils.omit_long_data(v)
err_msg += f"{k}: {repr(v)}\n"
err_msg += "\n"
# log response
err_msg += "====== sql response details ======\n"
for k, v in sql_resp.items():
v = utils.omit_long_data(v)
err_msg += f"{k}: {repr(v)}\n"
logger.error(err_msg)
# extract
extractors = step.extract
extract_mapping = resp_obj.extract(extractors)
step_result.export_vars = extract_mapping
variables_mapping = step.variables
variables_mapping.update(extract_mapping)
# validate
validators = step.validators
try:
resp_obj.validate(
validators, variables_mapping
)
step_result.success = True
except ValidationFailure:
log_sql_req_resp_details()
raise
finally:
session_data = runner.session.data
session_data.success = step_result.success
session_data.validators = resp_obj.validation_results
# save step data
step_result.data = session_data
step_result.elapsed = time.time() - start_time
return step_result
class StepSqlRequestValidation(StepRequestValidation):
def __init__(self, step: TStep):
self.__step = step
super().__init__(step)
def run(self, runner: HttpRunner):
return run_step_sql_request(runner, self.__step)
class StepSqlRequestExtraction(StepRequestExtraction):
def __init__(self, step: TStep):
self.__step = step
super().__init__(step)
def run(self, runner: HttpRunner):
return run_step_sql_request(runner, self.__step)
def validate(self) -> StepSqlRequestValidation:
return StepSqlRequestValidation(self.__step)
class RunSqlRequest(IStep):
def __init__(self, name: Text):
self.__step = TStep(name=name)
self.__step.sql_request = TSqlRequest()
def with_variables(self, **variables) -> "RunSqlRequest":
self.__step.variables.update(variables)
return self
def with_db_config(self, psm=None, user=None, password=None, ip=None, port=None, database=None):
if psm:
self.__step.sql_request.db_config.psm = psm
if user:
self.__step.sql_request.db_config.user = user
if password:
self.__step.sql_request.db_config.password = password
if ip:
self.__step.sql_request.db_config.ip = ip
if port:
self.__step.sql_request.db_config.port = port
if database:
self.__step.sql_request.db_config.database = database
return self
def fetchone(self, sql) -> "RunSqlRequest":
self.__step.sql_request.method = SqlMethodEnum.FETCHONE
self.__step.sql_request.sql = sql
return self
def fetchmany(self, sql, size) -> "RunSqlRequest":
self.__step.sql_request.method = SqlMethodEnum.FETCHMANY
self.__step.sql_request.sql = sql
self.__step.sql_request.size = size
return self
def fetchall(self, sql) -> "RunSqlRequest":
self.__step.sql_request.method = SqlMethodEnum.FETCHALL
self.__step.sql_request.sql = sql
return self
def update(self, sql) -> "RunSqlRequest":
self.__step.sql_request.method = SqlMethodEnum.UPDATE
self.__step.sql_request.sql = sql
return self
def delete(self, sql) -> "RunSqlRequest":
self.__step.sql_request.method = SqlMethodEnum.DELETE
self.__step.sql_request.sql = sql
return self
def insert(self, sql) -> "RunSqlRequest":
self.__step.sql_request.method = SqlMethodEnum.INSERT
self.__step.sql_request.sql = sql
return self
def with_retry(self, retry_times, retry_interval) -> "RunSqlRequest":
self.__step.retry_times = retry_times
self.__step.retry_interval = retry_interval
return self
def teardown_hook(self, hook: Text, assign_var_name: Text = None) -> "RunSqlRequest":
if assign_var_name:
self.__step.teardown_hooks.append({assign_var_name: hook})
else:
self.__step.teardown_hooks.append(hook)
return self
def setup_hook(self, hook: Text, assign_var_name: Text = None) -> "RunSqlRequest":
if assign_var_name:
self.__step.setup_hooks.append({assign_var_name: hook})
else:
self.__step.setup_hooks.append(hook)
return self
def struct(self) -> TStep:
return self.__step
def name(self) -> Text:
return self.__step.name
def type(self) -> Text:
return f"sql-request-{self.__step.sql_request.sql}"
def run(self, runner) -> StepResult:
return run_step_sql_request(runner, self.__step)
def extract(self) -> StepSqlRequestExtraction:
return StepSqlRequestExtraction(self.__step)
def validate(self) -> StepSqlRequestValidation:
return StepSqlRequestValidation(self.__step)
def with_jmespath(self, jmes_path: Text, var_name: Text) -> "StepSqlRequestExtraction":
self.__step.extract[var_name] = jmes_path
return StepSqlRequestExtraction(self.__step)