diff --git a/httprunner/loader.py b/httprunner/loader.py index 2203da48..985a8592 100644 --- a/httprunner/loader.py +++ b/httprunner/loader.py @@ -1,11 +1,12 @@ import collections import csv +import importlib import io import json import os import yaml -from httprunner import exceptions, logger, parser, validator +from httprunner import exceptions, logger, parser, utils, validator from httprunner.compat import OrderedDict ############################################################################### @@ -165,6 +166,50 @@ def load_dot_env_file(path): return env_variables_mapping +############################################################################### +## debugtalk.py module loader +############################################################################### + +def locate_debugtalk_py(start_dir_path): + """ locate debugtalk.py module and return module name + e.g. + debugtalk.py => "debugtalk" + tests/debugtalk.py => "tests.debugtalk" + """ + module_path = os.path.join(start_dir_path, "debugtalk.py") + if os.path.isfile(module_path): + return "debugtalk" + + # make compatible with former version + module_path = os.path.join(start_dir_path, "tests", "debugtalk.py") + if os.path.isfile(module_path): + return "tests.debugtalk" + + return None + + +def load_debugtalk_module(module_name=None): + """ load debugtalk.py module + @param (str) module_name + e.g. debugtalk + tests.debugtalk + """ + module_name = module_name or locate_debugtalk_py(os.getcwd()) + + if not module_name: + return {} + + try: + imported_module = importlib.import_module(module_name) + except ImportError: + raise exceptions.ParamsError("module name error: {}".format(module_name)) + + return { + "variables": utils.filter_module(imported_module, "variable"), + "functions": utils.filter_module(imported_module, "function") + } + + ############################################################################### ## suite loader ############################################################################### diff --git a/tests/test_loader.py b/tests/test_loader.py index 412c9df4..102e9a83 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -141,6 +141,40 @@ class TestFileLoader(unittest.TestCase): loader.load_dot_env_file("not_exist.env") +class TestModuleLoader(unittest.TestCase): + + def test_locate_debugtalk_py(self): + self.assertEqual(loader.locate_debugtalk_py(os.getcwd()), "tests.debugtalk") + + start_dir_path = os.path.join(os.getcwd(), "tests") + self.assertEqual( + loader.locate_debugtalk_py(start_dir_path), + "debugtalk" + ) + + start_dir_path = os.path.join(os.getcwd(), "not_exist") + self.assertEqual( + loader.locate_debugtalk_py(start_dir_path), + None + ) + + def test_load_debugtalk_module(self): + imported_module_items = loader.load_debugtalk_module("tests.debugtalk") + print(imported_module_items) + self.assertEqual( + imported_module_items["variables"]["SECRET_KEY"], + "DebugTalk" + ) + self.assertIn("alter_response", imported_module_items["functions"]) + + is_status_code_200 = imported_module_items["functions"]["is_status_code_200"] + self.assertTrue(is_status_code_200(200)) + self.assertFalse(is_status_code_200(500)) + + with self.assertRaises(exceptions.ParamsError): + loader.load_debugtalk_module("debugtalk") + + class TestSuiteLoader(unittest.TestCase): def setUp(self):