diff --git a/httprunner/loader.py b/httprunner/loader.py index 6f07f601..6c3ada5e 100644 --- a/httprunner/loader.py +++ b/httprunner/loader.py @@ -170,43 +170,49 @@ def load_dot_env_file(path): ## debugtalk.py module loader ############################################################################### -def locate_debugtalk_py(start_dir_path): +def locate_debugtalk_py(start_path): """ locate debugtalk.py module and return module name. + searching will be recursive upward until current working directory. Args: - start_dir_path (str): start locating directory path + start_path (str): start locating path, maybe file path or directory path Returns: str: located module name. None if module not found. - Examples: - # CWD/debugtalk.py - >>> locate_debugtalk_py("/path/to/CWD") - debugtalk - - # CWD/tests/debugtalk.py - >>> locate_debugtalk_py("/path/to/CWD") - tests.debugtalk + Raises: + exceptions.FileNotFound: If failed to locate debugtalk.py module. """ + if os.path.isfile(start_path): + start_dir_path = os.path.dirname(start_path) + elif os.path.isdir(start_path): + start_dir_path = start_path + else: + raise exceptions.FileNotFound("invalid path: {}".format(start_path)) + module_path = os.path.join(start_dir_path, "debugtalk.py") if os.path.isfile(module_path): - return "debugtalk" + if os.path.isabs(module_path): + module_path = module_path[len(os.getcwd())+1:] - # make compatible with former version - # TODO: remove this compatiblity - module_path = os.path.join(start_dir_path, "tests", "debugtalk.py") - if os.path.isfile(module_path): - return "tests.debugtalk" + module_name = module_path.replace("/", ".").rstrip(".py") + return module_name - return None + # current working directory + if os.path.abspath(start_dir_path) == os.getcwd(): + raise exceptions.FileNotFound("debugtalk.py module not found: {}".format(start_path)) + + # locate recursive upward + return locate_debugtalk_py(os.path.dirname(start_dir_path)) -def load_debugtalk_module(module_name=None): +def load_debugtalk_module(start_path=None): """ load debugtalk.py module. Args: - module_name (str, optional): module name for debugtalk.py. Defaults to None. + start_path (str, optional): start locating path, maybe file path or directory path. + Defaults to current working directory. Returns: dict: variables and functions mapping for debugtalk.py @@ -216,34 +222,22 @@ def load_debugtalk_module(module_name=None): "functions": {} } - Examples: - # debugtalk.py - >>> load_debugtalk_module() - debugtalk - - # tests/debugtalk.py - >>> load_debugtalk_module() - tests.debugtalk - - Raises: - exceptions.ParamsError: If failed to import specified module. - """ - module_name = module_name or locate_debugtalk_py(os.getcwd()) - - if not module_name: - return {} + start_path = start_path or os.getcwd() try: - imported_module = importlib.import_module(module_name) - except ImportError: - raise exceptions.ParamsError("module name error: {}".format(module_name)) + module_name = locate_debugtalk_py(start_path) + except exceptions.FileNotFound: + return {} - return { + imported_module = importlib.import_module(module_name) + debugtalk_module = { "variables": utils.filter_module(imported_module, "variable"), "functions": utils.filter_module(imported_module, "function") } + return debugtalk_module + ############################################################################### ## suite loader diff --git a/tests/test_loader.py b/tests/test_loader.py index 102e9a83..2542e42a 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -144,23 +144,39 @@ class TestFileLoader(unittest.TestCase): class TestModuleLoader(unittest.TestCase): def test_locate_debugtalk_py(self): - self.assertEqual(loader.locate_debugtalk_py(os.getcwd()), "tests.debugtalk") + with self.assertRaises(exceptions.FileNotFound): + loader.locate_debugtalk_py(os.getcwd()) - start_dir_path = os.path.join(os.getcwd(), "tests") + with self.assertRaises(exceptions.FileNotFound): + loader.locate_debugtalk_py("") + + start_path = os.path.join(os.getcwd(), "tests") self.assertEqual( - loader.locate_debugtalk_py(start_dir_path), - "debugtalk" + loader.locate_debugtalk_py(start_path), + "tests.debugtalk" ) - - start_dir_path = os.path.join(os.getcwd(), "not_exist") self.assertEqual( - loader.locate_debugtalk_py(start_dir_path), - None + loader.locate_debugtalk_py("tests/"), + "tests.debugtalk" + ) + self.assertEqual( + loader.locate_debugtalk_py("tests"), + "tests.debugtalk" + ) + self.assertEqual( + loader.locate_debugtalk_py("tests/base.py"), + "tests.debugtalk" + ) + self.assertEqual( + loader.locate_debugtalk_py("tests/data/test.env"), + "tests.debugtalk" ) def test_load_debugtalk_module(self): - imported_module_items = loader.load_debugtalk_module("tests.debugtalk") - print(imported_module_items) + imported_module_items = loader.load_debugtalk_module() + self.assertEqual(imported_module_items, {}) + + imported_module_items = loader.load_debugtalk_module("tests") self.assertEqual( imported_module_items["variables"]["SECRET_KEY"], "DebugTalk" @@ -171,9 +187,6 @@ class TestModuleLoader(unittest.TestCase): self.assertTrue(is_status_code_200(200)) self.assertFalse(is_status_code_200(500)) - with self.assertRaises(exceptions.ParamsError): - loader.load_debugtalk_module("debugtalk") - class TestSuiteLoader(unittest.TestCase):