mirror of
https://github.com/httprunner/httprunner.git
synced 2026-05-12 02:21:29 +08:00
add:sql and thrift as step
This commit is contained in:
@@ -7,6 +7,8 @@ from httprunner.runner import HttpRunner
|
||||
from httprunner.step import Step
|
||||
from httprunner.step_request import RunRequest
|
||||
from httprunner.step_testcase import RunTestCase
|
||||
from httprunner.step_sql_request import RunSqlRequest, StepSqlRequestValidation, StepSqlRequestExtraction
|
||||
from httprunner.step_thrift_request import RunThriftRequest, StepThriftRequestValidation, StepThriftRequestExtraction
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
@@ -15,6 +17,12 @@ __all__ = [
|
||||
"Config",
|
||||
"Step",
|
||||
"RunRequest",
|
||||
"RunSqlRequest",
|
||||
"StepSqlRequestValidation",
|
||||
"StepSqlRequestExtraction",
|
||||
"RunThriftRequest",
|
||||
"StepThriftRequestValidation",
|
||||
"StepThriftRequestExtraction",
|
||||
"RunTestCase",
|
||||
"Parameters",
|
||||
]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import inspect
|
||||
from typing import Text
|
||||
|
||||
from httprunner.models import TConfig, TConfigThrift
|
||||
from httprunner.models import TConfig, TConfigThrift, TConfigDB, ProtoType
|
||||
|
||||
|
||||
class ConfigThrift(object):
|
||||
@@ -21,8 +21,65 @@ class ConfigThrift(object):
|
||||
self.__config.thrift.cluster = cluster
|
||||
return self
|
||||
|
||||
def target(self, target: Text) -> "ConfigThrift":
|
||||
self.__config.thrift.target = target
|
||||
def service_name(self, service_name: Text) -> "ConfigThrift":
|
||||
self.__config.thrift.service_name = service_name
|
||||
return self
|
||||
|
||||
def method(self, method: Text) -> "ConfigThrift":
|
||||
self.__config.thrift.method = method
|
||||
return self
|
||||
|
||||
def ip(self, service_name_: Text) -> "ConfigThrift":
|
||||
self.__config.thrift.service_name = service_name_
|
||||
return self
|
||||
|
||||
def port(self, port: int) -> "ConfigThrift":
|
||||
self.__config.thrift.port = port
|
||||
return self
|
||||
|
||||
def timeout(self, timeout: int) -> "ConfigThrift":
|
||||
self.__config.thrift.timeout = timeout
|
||||
return self
|
||||
|
||||
def proto_type(self, proto_type: ProtoType) -> "ConfigThrift":
|
||||
self.__config.thrift.proto_type = proto_type
|
||||
return self
|
||||
|
||||
def trans_type(self, trans_type: ProtoType) -> "ConfigThrift":
|
||||
self.__config.thrift.trans_type = trans_type
|
||||
return self
|
||||
|
||||
def struct(self) -> TConfig:
|
||||
return self.__config
|
||||
|
||||
|
||||
class ConfigDB(object):
|
||||
def __init__(self, config: TConfig):
|
||||
self.__config = config
|
||||
self.__config.db = TConfigDB()
|
||||
|
||||
def psm(self, psm):
|
||||
self.__config.db.psm = psm
|
||||
return self
|
||||
|
||||
def user(self, user):
|
||||
self.__config.db.user = user
|
||||
return self
|
||||
|
||||
def password(self, password):
|
||||
self.__config.db.password = password
|
||||
return self
|
||||
|
||||
def ip(self, ip):
|
||||
self.__config.db.ip = ip
|
||||
return self
|
||||
|
||||
def port(self, port: int):
|
||||
self.__config.db.port = port
|
||||
return self
|
||||
|
||||
def database(self, database: Text):
|
||||
self.__config.db.database = database
|
||||
return self
|
||||
|
||||
def struct(self) -> TConfig:
|
||||
@@ -64,3 +121,6 @@ class Config(object):
|
||||
|
||||
def thrift(self) -> ConfigThrift:
|
||||
return ConfigThrift(self.__config)
|
||||
|
||||
def db(self) -> ConfigDB:
|
||||
return ConfigDB(self.__config)
|
||||
|
||||
79
httprunner/database/engine.py
Normal file
79
httprunner/database/engine.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
import json
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
|
||||
class DBEngine(object):
|
||||
def __init__(self, db_uri):
|
||||
"""
|
||||
db_uri = f'mysql+pymysql://{username}:{password}@{host}:{port}/{database}?charset=utf8mb4'
|
||||
|
||||
"""
|
||||
engine = create_engine(db_uri)
|
||||
self.session = sessionmaker(bind=engine)()
|
||||
|
||||
@staticmethod
|
||||
def value_decode(row: dict):
|
||||
"""
|
||||
Try to decode value of table
|
||||
datetime.datetime-->string
|
||||
datetime.date-->string
|
||||
json str-->dict
|
||||
:param row:
|
||||
:return:
|
||||
"""
|
||||
for k, v in row.items():
|
||||
if isinstance(v, datetime.datetime):
|
||||
row[k] = v.strftime('%Y-%m-%d %H:%M:%S')
|
||||
elif isinstance(v, datetime.date):
|
||||
row[k] = v.strftime("%Y-%m-%d")
|
||||
elif isinstance(v, str):
|
||||
try:
|
||||
row[k] = json.loads(v)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def _fetch(self, query, size=-1, commit=True):
|
||||
result = self.session.execute(query)
|
||||
self.session.commit() if commit else 0
|
||||
if query.upper()[:6] == "SELECT":
|
||||
if size < 0:
|
||||
al = result.fetchall()
|
||||
al = [dict(el) for el in al]
|
||||
return al or None
|
||||
elif size == 1:
|
||||
on = dict(result.fetchone())
|
||||
self.value_decode(on)
|
||||
return on or None
|
||||
else:
|
||||
mny = result.fetchmany(size)
|
||||
mny = [dict(el) for el in mny]
|
||||
return mny or None
|
||||
elif query.upper()[:6] in ("UPDATE", "DELETE", "INSERT"):
|
||||
return {"rowcount": result.rowcount}
|
||||
|
||||
def fetchone(self, query, commit=True):
|
||||
return self._fetch(query, size=1, commit=commit)
|
||||
|
||||
def fetchmany(self, query, size, commit=True):
|
||||
return self._fetch(query=query, size=size, commit=commit)
|
||||
|
||||
def fetchall(self, query, commit=True):
|
||||
return self._fetch(query=query, size=-1, commit=commit)
|
||||
|
||||
def insert(self, query, commit=True):
|
||||
return self._fetch(query=query, commit=commit)
|
||||
|
||||
def delete(self, query, commit=True):
|
||||
return self._fetch(query=query, commit=commit)
|
||||
|
||||
def update(self, query, commit=True):
|
||||
return self._fetch(query=query, commit=commit)
|
||||
|
||||
if __name__ == '__main__':
|
||||
db = DBEngine(
|
||||
f"mysql+pymysql://xxxxx:xxxxx@10.0.0.1:3306/dbname?charset=utf8mb4")
|
||||
|
||||
@@ -86,3 +86,7 @@ class TestcaseNotFound(NotFoundError):
|
||||
|
||||
class SummaryEmpty(MyBaseError):
|
||||
"""test result summary data is empty"""
|
||||
|
||||
|
||||
class SqlMethodNotSupport(MyBaseError):
|
||||
pass
|
||||
|
||||
@@ -28,6 +28,20 @@ class MethodEnum(Text, Enum):
|
||||
PATCH = "PATCH"
|
||||
|
||||
|
||||
class ProtoType(Enum):
|
||||
pBinary = 1
|
||||
pCyBinary = 2
|
||||
pCompact = 3
|
||||
pJson = 4
|
||||
|
||||
|
||||
class TransType(Enum):
|
||||
tBuffered = 1
|
||||
tCyBuffered = 2
|
||||
tFramed = 3
|
||||
tCyFramed = 4
|
||||
|
||||
|
||||
# configs for thrift rpc
|
||||
class TConfigThrift(BaseModel):
|
||||
psm: Text = None
|
||||
@@ -36,6 +50,66 @@ class TConfigThrift(BaseModel):
|
||||
target: Text = None
|
||||
include_dirs: List[Text] = None
|
||||
thrift_client: Any = None
|
||||
timeout: int = 10
|
||||
idl_path: Text = None
|
||||
method: Text = None
|
||||
ip: Text = "127.0.0.1"
|
||||
port: int = 9000
|
||||
service_name: Text = None
|
||||
proto_type: ProtoType = ProtoType.pBinary
|
||||
trans_type: TransType = TransType.tBuffered
|
||||
|
||||
|
||||
# configs for db
|
||||
class TConfigDB(BaseModel):
|
||||
psm: Text = None
|
||||
user: Text = None
|
||||
password: Text = None
|
||||
ip: Text = None
|
||||
port: int = 3306
|
||||
database: Text = None
|
||||
|
||||
|
||||
class TransportEnum(Text, Enum):
|
||||
BUFFERED = "buffered"
|
||||
FRAMED = "framed"
|
||||
|
||||
|
||||
class TThriftRequest(BaseModel):
|
||||
""" rpc request model"""
|
||||
method: Text = ''
|
||||
params: Dict = {}
|
||||
thrift_client: Any = None
|
||||
idl_path: Text = '' # idl local path
|
||||
timeout: int = 10 # sec
|
||||
transport: TransportEnum = TransportEnum.BUFFERED
|
||||
include_dirs: List[Union[Text, None]] = [] # param of thriftpy2.load
|
||||
target: Text = "" # tcp://{ip}:{port} or sd://psm?cluster=xx&env=xx
|
||||
env: Text = "prod"
|
||||
cluster: Text = "default"
|
||||
psm: Text = ""
|
||||
service_name: Text = None
|
||||
ip: Text = None
|
||||
port: int = None
|
||||
proto_type: ProtoType = None
|
||||
trans_type: TransType = None
|
||||
|
||||
|
||||
class SqlMethodEnum(Text, Enum):
|
||||
FETCHONE = "FETCHONE"
|
||||
FETCHMANY = "FETCHMANY"
|
||||
FETCHALL = "FETCHALL"
|
||||
INSERT = "INSERT"
|
||||
UPDATE = "UPDATE"
|
||||
DELETE = "DELETE"
|
||||
|
||||
|
||||
class TSqlRequest(BaseModel):
|
||||
""" sql request model"""
|
||||
db_config: TConfigDB = TConfigDB()
|
||||
method: SqlMethodEnum = None
|
||||
sql: Text = None
|
||||
size: int = 0 # limit nums of sql result
|
||||
|
||||
|
||||
class TConfig(BaseModel):
|
||||
@@ -51,6 +125,7 @@ class TConfig(BaseModel):
|
||||
path: Text = None
|
||||
# configs for other protocols
|
||||
thrift: TConfigThrift = None
|
||||
db: TConfigDB = TConfigDB()
|
||||
|
||||
|
||||
class TRequest(BaseModel):
|
||||
@@ -84,6 +159,8 @@ class TStep(BaseModel):
|
||||
validate_script: List[Text] = []
|
||||
retry_times: int = 0
|
||||
retry_interval: int = 0 # sec
|
||||
thrift_request: Union[TThriftRequest, None] = None
|
||||
sql_request: Union[TSqlRequest, None] = None
|
||||
|
||||
|
||||
class TestCase(BaseModel):
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
from typing import Any, Dict, Text
|
||||
from typing import Dict, Text, Any
|
||||
|
||||
import jmespath
|
||||
import requests
|
||||
from jmespath.exceptions import JMESPathError
|
||||
from loguru import logger
|
||||
|
||||
from httprunner import exceptions
|
||||
from httprunner.exceptions import ParamsError, ValidationFailure
|
||||
from httprunner.models import Validators, VariablesMapping
|
||||
from httprunner.parser import Parser, parse_string_value
|
||||
from httprunner.exceptions import ValidationFailure, ParamsError
|
||||
from httprunner.models import VariablesMapping, Validators
|
||||
from httprunner.parser import parse_string_value, Parser
|
||||
|
||||
|
||||
def get_uniform_comparator(comparator: Text):
|
||||
"""convert comparator alias to uniform name"""
|
||||
""" convert comparator alias to uniform name"""
|
||||
if comparator in ["eq", "equals", "equal"]:
|
||||
return "equal"
|
||||
elif comparator in ["lt", "less_than"]:
|
||||
@@ -113,9 +112,9 @@ def uniform_validator(validator):
|
||||
}
|
||||
|
||||
|
||||
class ResponseObject(object):
|
||||
def __init__(self, resp_obj: requests.Response, parser: Parser):
|
||||
"""initialize with a requests.Response object
|
||||
class ResponseObjectBase(object):
|
||||
def __init__(self, resp_obj, parser: Parser):
|
||||
""" initialize with a response object
|
||||
|
||||
Args:
|
||||
resp_obj (instance): requests.Response instance
|
||||
@@ -125,71 +124,33 @@ class ResponseObject(object):
|
||||
self.parser = parser
|
||||
self.validation_results: Dict = {}
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key in ["json", "content", "body"]:
|
||||
try:
|
||||
value = self.resp_obj.json()
|
||||
except ValueError:
|
||||
value = self.resp_obj.content
|
||||
elif key == "cookies":
|
||||
value = self.resp_obj.cookies.get_dict()
|
||||
else:
|
||||
try:
|
||||
value = getattr(self.resp_obj, key)
|
||||
except AttributeError:
|
||||
err_msg = "ResponseObject does not have attribute: {}".format(key)
|
||||
logger.error(err_msg)
|
||||
raise exceptions.ParamsError(err_msg)
|
||||
|
||||
self.__dict__[key] = value
|
||||
return value
|
||||
|
||||
def _search_jmespath(self, expr: Text) -> Any:
|
||||
resp_obj_meta = {
|
||||
"status_code": self.status_code,
|
||||
"headers": self.headers,
|
||||
"cookies": self.cookies,
|
||||
"body": self.body,
|
||||
}
|
||||
if not expr.startswith(tuple(resp_obj_meta.keys())):
|
||||
return expr
|
||||
|
||||
try:
|
||||
check_value = jmespath.search(expr, resp_obj_meta)
|
||||
except JMESPathError as ex:
|
||||
logger.error(
|
||||
f"failed to search with jmespath\n"
|
||||
f"expression: {expr}\n"
|
||||
f"data: {resp_obj_meta}\n"
|
||||
f"exception: {ex}"
|
||||
)
|
||||
raise
|
||||
|
||||
return check_value
|
||||
|
||||
def extract(
|
||||
self,
|
||||
extractors: Dict[Text, Text],
|
||||
variables_mapping: VariablesMapping = None,
|
||||
) -> Dict[Text, Any]:
|
||||
def extract(self,
|
||||
extractors: Dict[Text, Text],
|
||||
variables_mapping: VariablesMapping = None,
|
||||
) -> Dict[Text, Any]:
|
||||
if not extractors:
|
||||
return {}
|
||||
|
||||
extract_mapping = {}
|
||||
for key, field in extractors.items():
|
||||
if "$" in field:
|
||||
if '$' in field:
|
||||
# field contains variable or function
|
||||
field = self.parser.parse_data(field, variables_mapping)
|
||||
field = self.parser.parse_data(
|
||||
field, variables_mapping
|
||||
)
|
||||
field_value = self._search_jmespath(field)
|
||||
extract_mapping[key] = field_value
|
||||
|
||||
logger.info(f"extract mapping: {extract_mapping}")
|
||||
return extract_mapping
|
||||
|
||||
def _search_jmespath(self, expr: Text) -> Any:
|
||||
raise NotImplementedError("_search_jmespath not override")
|
||||
|
||||
def validate(
|
||||
self,
|
||||
validators: Validators,
|
||||
variables_mapping: VariablesMapping = None,
|
||||
self,
|
||||
validators: Validators,
|
||||
variables_mapping: VariablesMapping = None,
|
||||
):
|
||||
|
||||
variables_mapping = variables_mapping or {}
|
||||
@@ -212,7 +173,9 @@ class ResponseObject(object):
|
||||
check_item = u_validator["check"]
|
||||
if "$" in check_item:
|
||||
# check_item is variable or function
|
||||
check_item = self.parser.parse_data(check_item, variables_mapping)
|
||||
check_item = self.parser.parse_data(
|
||||
check_item, variables_mapping
|
||||
)
|
||||
check_item = parse_string_value(check_item)
|
||||
|
||||
if check_item and isinstance(check_item, Text):
|
||||
@@ -274,3 +237,66 @@ class ResponseObject(object):
|
||||
if not validate_pass:
|
||||
failures_string = "\n".join([failure for failure in failures])
|
||||
raise ValidationFailure(failures_string)
|
||||
|
||||
|
||||
class ResponseObject(ResponseObjectBase):
|
||||
def __getattr__(self, key):
|
||||
if key in ["json", "content", "body"]:
|
||||
try:
|
||||
value = self.resp_obj.json()
|
||||
except ValueError:
|
||||
value = self.resp_obj.content
|
||||
elif key == "cookies":
|
||||
value = self.resp_obj.cookies.get_dict()
|
||||
else:
|
||||
try:
|
||||
value = getattr(self.resp_obj, key)
|
||||
except AttributeError:
|
||||
err_msg = "ResponseObject does not have attribute: {}".format(key)
|
||||
logger.error(err_msg)
|
||||
raise exceptions.ParamsError(err_msg)
|
||||
|
||||
self.__dict__[key] = value
|
||||
return value
|
||||
|
||||
def _search_jmespath(self, expr: Text) -> Any:
|
||||
resp_obj_meta = {
|
||||
"status_code": self.status_code,
|
||||
"headers": self.headers,
|
||||
"cookies": self.cookies,
|
||||
"body": self.body,
|
||||
}
|
||||
if not expr.startswith(tuple(resp_obj_meta.keys())):
|
||||
return expr
|
||||
|
||||
try:
|
||||
check_value = jmespath.search(expr, resp_obj_meta)
|
||||
except JMESPathError as ex:
|
||||
logger.error(
|
||||
f"failed to search with jmespath\n"
|
||||
f"expression: {expr}\n"
|
||||
f"data: {resp_obj_meta}\n"
|
||||
f"exception: {ex}"
|
||||
)
|
||||
raise
|
||||
|
||||
return check_value
|
||||
|
||||
|
||||
class ThriftResponseObject(ResponseObjectBase):
|
||||
def _search_jmespath(self, expr: Text) -> Any:
|
||||
try:
|
||||
check_value = jmespath.search(expr, self.resp_obj)
|
||||
except JMESPathError as ex:
|
||||
logger.error(
|
||||
f"failed to search with jmespath\n"
|
||||
f"expression: {expr}\n"
|
||||
f"data: {self.resp_obj}\n"
|
||||
f"exception: {ex}"
|
||||
)
|
||||
raise
|
||||
return check_value
|
||||
|
||||
|
||||
class SqlResponseObject(ThriftResponseObject):
|
||||
pass
|
||||
|
||||
@@ -38,6 +38,8 @@ class SessionRunner(object):
|
||||
session: HttpSession = None
|
||||
case_id: Text = ""
|
||||
root_dir: Text = ""
|
||||
thrift_client = None
|
||||
db_engine = None
|
||||
|
||||
__config: TConfig
|
||||
__project_meta: ProjectMeta = None
|
||||
@@ -87,6 +89,12 @@ class SessionRunner(object):
|
||||
self.__export = export
|
||||
return self
|
||||
|
||||
def with_thrift_client(self, thrift_client) -> "SessionRunner":
|
||||
self.thrift_client = thrift_client
|
||||
|
||||
def with_db_engine(self,db_engine):
|
||||
self.db_engine = db_engine
|
||||
|
||||
def __parse_config(self, param: Dict = None) -> None:
|
||||
# parse config variables
|
||||
self.__config.variables.update(self.__session_variables)
|
||||
|
||||
@@ -8,6 +8,7 @@ from httprunner.step_request import (
|
||||
StepRequestValidation,
|
||||
)
|
||||
from httprunner.step_testcase import StepRefCase
|
||||
from httprunner.step_sql_request import RunSqlRequest, StepSqlRequestValidation, StepSqlRequestExtraction
|
||||
|
||||
|
||||
class Step(object):
|
||||
@@ -18,6 +19,9 @@ class Step(object):
|
||||
StepRequestExtraction,
|
||||
RequestWithOptionalArgs,
|
||||
StepRefCase,
|
||||
RunSqlRequest,
|
||||
StepSqlRequestValidation,
|
||||
StepSqlRequestExtraction,
|
||||
],
|
||||
):
|
||||
self.__step = step
|
||||
|
||||
244
httprunner/step_sql_request.py
Normal file
244
httprunner/step_sql_request.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import time
|
||||
from typing import Text
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from httprunner import utils
|
||||
from httprunner.exceptions import ValidationFailure
|
||||
from httprunner.models import IStep, StepResult, TStep
|
||||
from httprunner.models import TSqlRequest, SqlMethodEnum
|
||||
from httprunner.response import SqlResponseObject
|
||||
from httprunner.runner import HttpRunner
|
||||
from httprunner.step_request import call_hooks, StepRequestExtraction, StepRequestValidation
|
||||
from httprunner.database.engine import DBEngine
|
||||
from httprunner.exceptions import SqlMethodNotSupport
|
||||
|
||||
|
||||
def run_step_sql_request(runner: HttpRunner, step: TStep) -> StepResult:
|
||||
"""run teststep:sql request"""
|
||||
start_time = time.time()
|
||||
|
||||
step_result = StepResult(
|
||||
name=step.name,
|
||||
success=False,
|
||||
)
|
||||
step.variables = runner.merge_step_variables(step.variables)
|
||||
# parse
|
||||
request_dict = step.sql_request.dict()
|
||||
parsed_request_dict = runner.parser.parse_data(
|
||||
request_dict, step.variables
|
||||
)
|
||||
config = runner.get_config()
|
||||
parsed_request_dict["db_config"]["psm"] = parsed_request_dict["db_config"]["psm"] or config.db.psm
|
||||
parsed_request_dict["db_config"]["user"] = parsed_request_dict["db_config"]["user"] or config.db.user
|
||||
parsed_request_dict["db_config"]["password"] = parsed_request_dict["db_config"]["password"] or config.db.password
|
||||
parsed_request_dict["db_config"]["ip"] = parsed_request_dict["db_config"]["ip"] or config.db.ip
|
||||
parsed_request_dict["db_config"]["port"] = parsed_request_dict["db_config"]["port"] or config.db.port
|
||||
parsed_request_dict["db_config"]["database"] = parsed_request_dict["db_config"]["database"] or config.db.database
|
||||
|
||||
if parsed_request_dict["db_config"]["psm"]:
|
||||
runner.db_engine = DBEngine(f'mysql+pymysql://:@/?charset=utf8mb4&db_psm={parsed_request_dict["psm"]}')
|
||||
else:
|
||||
runner.db_engine = DBEngine(
|
||||
f'mysql+pymysql://{parsed_request_dict["db_config"]["user"]}:'
|
||||
f'{parsed_request_dict["db_config"]["password"]}@{parsed_request_dict["db_config"]["ip"]}:'
|
||||
f'{parsed_request_dict["db_config"]["port"]}/{parsed_request_dict["db_config"]["database"]}'
|
||||
f'?charset=utf8mb4')
|
||||
|
||||
# parsed_request_dict["headers"].setdefault(
|
||||
# "HRUN-Request-ID",
|
||||
# f"HRUN-{self.__case_id}-{str(int(time.time() * 1000))[-6:]}",
|
||||
# )
|
||||
|
||||
# setup hooks
|
||||
if step.setup_hooks:
|
||||
call_hooks(runner, step.setup_hooks, step.variables, "setup request")
|
||||
|
||||
logger.info(f"Executing SQL: {parsed_request_dict['sql']}")
|
||||
if step.sql_request.method == SqlMethodEnum.FETCHONE:
|
||||
sql_resp = runner.db_engine.fetchone(parsed_request_dict['sql'])
|
||||
elif step.sql_request.method == SqlMethodEnum.INSERT:
|
||||
sql_resp = runner.db_engine.insert(parsed_request_dict['sql'])
|
||||
elif step.sql_request.method == SqlMethodEnum.FETCHMANY:
|
||||
sql_resp = runner.db_engine.fetchmany(parsed_request_dict['sql'], parsed_request_dict['size'])
|
||||
elif step.sql_request.method == SqlMethodEnum.FETCHALL:
|
||||
sql_resp = runner.db_engine.fetchall(parsed_request_dict['sql'])
|
||||
elif step.sql_request.method == SqlMethodEnum.UPDATE:
|
||||
sql_resp = runner.db_engine.update(parsed_request_dict['sql'])
|
||||
elif step.sql_request.method == SqlMethodEnum.DELETE:
|
||||
sql_resp = runner.db_engine.delete(parsed_request_dict['sql'])
|
||||
else:
|
||||
raise SqlMethodNotSupport(f"step.sql_request.method {parsed_request_dict['method']} not support")
|
||||
resp_obj = SqlResponseObject(sql_resp, parser=runner.parser)
|
||||
step.variables["sql_response"] = resp_obj
|
||||
|
||||
# teardown hooks
|
||||
if step.teardown_hooks:
|
||||
call_hooks(runner, step.teardown_hooks, step.variables, "teardown request")
|
||||
|
||||
def log_sql_req_resp_details():
|
||||
err_msg = "\n{} SQL DETAILED REQUEST & RESPONSE {}\n".format("*" * 32, "*" * 32)
|
||||
|
||||
# log request
|
||||
err_msg += "====== sql request details ======\n"
|
||||
err_msg += f"sql: {step.sql_request.sql}\n"
|
||||
for k, v in parsed_request_dict.items():
|
||||
v = utils.omit_long_data(v)
|
||||
err_msg += f"{k}: {repr(v)}\n"
|
||||
|
||||
err_msg += "\n"
|
||||
|
||||
# log response
|
||||
err_msg += "====== sql response details ======\n"
|
||||
for k, v in sql_resp.items():
|
||||
v = utils.omit_long_data(v)
|
||||
err_msg += f"{k}: {repr(v)}\n"
|
||||
logger.error(err_msg)
|
||||
|
||||
# extract
|
||||
extractors = step.extract
|
||||
extract_mapping = resp_obj.extract(extractors)
|
||||
step_result.export_vars = extract_mapping
|
||||
|
||||
variables_mapping = step.variables
|
||||
variables_mapping.update(extract_mapping)
|
||||
|
||||
# validate
|
||||
validators = step.validators
|
||||
try:
|
||||
resp_obj.validate(
|
||||
validators, variables_mapping
|
||||
)
|
||||
step_result.success = True
|
||||
except ValidationFailure:
|
||||
log_sql_req_resp_details()
|
||||
raise
|
||||
finally:
|
||||
session_data = runner.session.data
|
||||
session_data.success = step_result.success
|
||||
session_data.validators = resp_obj.validation_results
|
||||
# save step data
|
||||
step_result.data = session_data
|
||||
step_result.elapsed = time.time() - start_time
|
||||
return step_result
|
||||
|
||||
|
||||
class StepSqlRequestValidation(StepRequestValidation):
|
||||
def __init__(self, step: TStep):
|
||||
self.__step = step
|
||||
super().__init__(step)
|
||||
|
||||
def run(self, runner: HttpRunner):
|
||||
return run_step_sql_request(runner, self.__step)
|
||||
|
||||
|
||||
class StepSqlRequestExtraction(StepRequestExtraction):
|
||||
def __init__(self, step: TStep):
|
||||
self.__step = step
|
||||
super().__init__(step)
|
||||
|
||||
def run(self, runner: HttpRunner):
|
||||
return run_step_sql_request(runner, self.__step)
|
||||
|
||||
def validate(self) -> StepSqlRequestValidation:
|
||||
return StepSqlRequestValidation(self.__step)
|
||||
|
||||
|
||||
class RunSqlRequest(IStep):
|
||||
def __init__(self, name: Text):
|
||||
self.__step = TStep(name=name)
|
||||
self.__step.sql_request = TSqlRequest()
|
||||
|
||||
def with_variables(self, **variables) -> "RunSqlRequest":
|
||||
self.__step.variables.update(variables)
|
||||
return self
|
||||
|
||||
def with_db_config(self, psm=None, user=None, password=None, ip=None, port=None, database=None):
|
||||
if psm:
|
||||
self.__step.sql_request.db_config.psm = psm
|
||||
if user:
|
||||
self.__step.sql_request.db_config.user = user
|
||||
if password:
|
||||
self.__step.sql_request.db_config.password = password
|
||||
if ip:
|
||||
self.__step.sql_request.db_config.ip = ip
|
||||
if port:
|
||||
self.__step.sql_request.db_config.port = port
|
||||
if database:
|
||||
self.__step.sql_request.db_config.database = database
|
||||
return self
|
||||
|
||||
def fetchone(self, sql) -> "RunSqlRequest":
|
||||
self.__step.sql_request.method = SqlMethodEnum.FETCHONE
|
||||
self.__step.sql_request.sql = sql
|
||||
return self
|
||||
|
||||
def fetchmany(self, sql, size) -> "RunSqlRequest":
|
||||
self.__step.sql_request.method = SqlMethodEnum.FETCHMANY
|
||||
self.__step.sql_request.sql = sql
|
||||
self.__step.sql_request.size = size
|
||||
return self
|
||||
|
||||
def fetchall(self, sql) -> "RunSqlRequest":
|
||||
self.__step.sql_request.method = SqlMethodEnum.FETCHALL
|
||||
self.__step.sql_request.sql = sql
|
||||
return self
|
||||
|
||||
def update(self, sql) -> "RunSqlRequest":
|
||||
self.__step.sql_request.method = SqlMethodEnum.UPDATE
|
||||
self.__step.sql_request.sql = sql
|
||||
return self
|
||||
|
||||
def delete(self, sql) -> "RunSqlRequest":
|
||||
self.__step.sql_request.method = SqlMethodEnum.DELETE
|
||||
self.__step.sql_request.sql = sql
|
||||
return self
|
||||
|
||||
def insert(self, sql) -> "RunSqlRequest":
|
||||
self.__step.sql_request.method = SqlMethodEnum.INSERT
|
||||
self.__step.sql_request.sql = sql
|
||||
return self
|
||||
|
||||
def with_retry(self, retry_times, retry_interval) -> "RunSqlRequest":
|
||||
self.__step.retry_times = retry_times
|
||||
self.__step.retry_interval = retry_interval
|
||||
return self
|
||||
|
||||
def teardown_hook(self, hook: Text, assign_var_name: Text = None) -> "RunSqlRequest":
|
||||
if assign_var_name:
|
||||
self.__step.teardown_hooks.append({assign_var_name: hook})
|
||||
else:
|
||||
self.__step.teardown_hooks.append(hook)
|
||||
|
||||
return self
|
||||
|
||||
def setup_hook(self, hook: Text, assign_var_name: Text = None) -> "RunSqlRequest":
|
||||
if assign_var_name:
|
||||
self.__step.setup_hooks.append({assign_var_name: hook})
|
||||
else:
|
||||
self.__step.setup_hooks.append(hook)
|
||||
|
||||
return self
|
||||
|
||||
def struct(self) -> TStep:
|
||||
return self.__step
|
||||
|
||||
def name(self) -> Text:
|
||||
return self.__step.name
|
||||
|
||||
def type(self) -> Text:
|
||||
return f"sql-request-{self.__step.sql_request.sql}"
|
||||
|
||||
def run(self, runner) -> StepResult:
|
||||
return run_step_sql_request(runner, self.__step)
|
||||
|
||||
def extract(self) -> StepSqlRequestExtraction:
|
||||
return StepSqlRequestExtraction(self.__step)
|
||||
|
||||
def validate(self) -> StepSqlRequestValidation:
|
||||
return StepSqlRequestValidation(self.__step)
|
||||
|
||||
def with_jmespath(self, jmes_path: Text, var_name: Text) -> "StepSqlRequestExtraction":
|
||||
self.__step.extract[var_name] = jmes_path
|
||||
return StepSqlRequestExtraction(self.__step)
|
||||
Reference in New Issue
Block a user