diff --git a/ate/context.py b/ate/context.py index 2665754d..aeef98db 100644 --- a/ate/context.py +++ b/ate/context.py @@ -1,7 +1,17 @@ -import re import importlib +import re +import types + from ate import exception, utils + +def is_function(tup): + """ + Takes (name, object) tuple, returns True if it is a function. + """ + name, item = tup + return isinstance(item, types.FunctionType) + class Context(object): """ Manages binding of variables """ @@ -29,6 +39,14 @@ class Context(object): function = eval(function) self.functions[func_name] = function + def import_module_functions(self, modules): + """ import modules and bind all functions within the context + """ + for module_name in modules: + imported = importlib.import_module(module_name) + imported_functions_dict = dict(filter(is_function, vars(imported).items())) + self.functions.update(imported_functions_dict) + def bind_variables(self, variable_binds): """ Bind named variables to value within the context. This allows for passing in variables or functions. diff --git a/ate/runner.py b/ate/runner.py index 3b6dea70..02814f62 100644 --- a/ate/runner.py +++ b/ate/runner.py @@ -41,6 +41,9 @@ class TestRunner(object): function_binds = config_dict.get('function_binds', {}) self.context.bind_functions(function_binds) + module_functions = config_dict.get('import_module_functions', []) + self.context.import_module_functions(module_functions) + variable_binds = config_dict.get('variable_binds', []) self.context.bind_variables(variable_binds) diff --git a/test/data/__init__.py b/test/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/data/custom_functions.py b/test/data/custom_functions.py new file mode 100644 index 00000000..10e098b9 --- /dev/null +++ b/test/data/custom_functions.py @@ -0,0 +1,42 @@ +import hashlib +import json +import random +import string + +try: + string_type = basestring + PYTHON_VERSION = 2 +except NameError: + string_type = str + PYTHON_VERSION = 3 + + +def gen_random_string(str_len): + return ''.join( + random.choice(string.ascii_letters + string.digits) for _ in range(str_len)) + +def gen_md5(*args): + args = [handle_req_data(item) for item in args] + return hashlib.md5("".join(args).encode('utf-8')).hexdigest() + +def handle_req_data(data): + + if PYTHON_VERSION == 3 and isinstance(data, bytes): + # In Python3, convert bytes to str + data = data.decode('utf-8') + + if not data: + return data + + if isinstance(data, str): + # check if data in str can be converted to dict + try: + data = json.loads(data) + except ValueError: + pass + + if isinstance(data, dict): + # sort data in dict with keys, then convert to str + data = json.dumps(data, sort_keys=True) + + return data diff --git a/test/data/demo_import_functions.yml b/test/data/demo_import_functions.yml new file mode 100644 index 00000000..e48b5f69 --- /dev/null +++ b/test/data/demo_import_functions.yml @@ -0,0 +1,41 @@ +- config: + name: "create user testsets." + import_module_functions: + - test.data.custom_functions + variable_binds: + - TOKEN: debugtalk + - json: {"name": "user", "password": "123456"} + - random: {"func": "gen_random_string", "args": [5]} + - authorization: {"func": "gen_md5", "args": ["${TOKEN}", "${json}", "${random}"]} + +- test: + name: create user which does not exist + variable_binds: + - json: {"name": "user", "password": "123456"} + request: + url: http://127.0.0.1:5000/api/users/1000 + method: POST + headers: + Content-Type: application/json + authorization: "${authorization}" + random: "${random}" + json: "${json}" + validators: + - {"check": "status_code", "comparator": "eq", "expected": 201} + - {"check": "content.success", "comparator": "eq", "expected": true} + +- test: + name: create user which does not exist + variable_binds: + - json: {"name": "user", "password": "123456"} + request: + url: http://127.0.0.1:5000/api/users/1000 + method: POST + headers: + Content-Type: application/json + authorization: "${authorization}" + random: "${random}" + json: "${json}" + validators: + - {"check": "status_code", "comparator": "eq", "expected": 500} + - {"check": "content.success", "comparator": "eq", "expected": false} diff --git a/test/test_runner_v2.py b/test/test_runner_v2.py index d7ae1b63..8cb90118 100644 --- a/test/test_runner_v2.py +++ b/test/test_runner_v2.py @@ -87,3 +87,19 @@ class TestRunnerV2(ApiServerUnittest): results = self.test_runner.run_testsets(testsets) self.assertEqual(len(results), 1) self.assertEqual(results[0], [(True, []), (True, [])]) + + def test_run_testset_template_import_functions(self): + testcase_file_path = os.path.join( + os.getcwd(), 'test/data/demo_import_functions.yml') + testsets = utils.load_testcases_by_path(testcase_file_path) + results = self.test_runner.run_testset(testsets[0]) + self.assertEqual(len(results), 2) + self.assertEqual(results, [(True, []), (True, [])]) + + def test_run_testsets_template_import_functions(self): + testcase_file_path = os.path.join( + os.getcwd(), 'test/data/demo_import_functions.yml') + testsets = utils.load_testcases_by_path(testcase_file_path) + results = self.test_runner.run_testsets(testsets) + self.assertEqual(len(results), 1) + self.assertEqual(results[0], [(True, []), (True, [])]) diff --git a/testcases/__init__.py b/testcases/__init__.py new file mode 100644 index 00000000..e69de29b