refactor load debugtalk.py module

This commit is contained in:
debugtalk
2018-08-08 18:23:42 +08:00
parent 33120fb183
commit a33384bb37
2 changed files with 59 additions and 52 deletions

View File

@@ -170,43 +170,49 @@ def load_dot_env_file(path):
## debugtalk.py module loader
###############################################################################
def locate_debugtalk_py(start_dir_path):
def locate_debugtalk_py(start_path):
""" locate debugtalk.py module and return module name.
searching will be recursive upward until current working directory.
Args:
start_dir_path (str): start locating directory path
start_path (str): start locating path, maybe file path or directory path
Returns:
str: located module name. None if module not found.
Examples:
# CWD/debugtalk.py
>>> locate_debugtalk_py("/path/to/CWD")
debugtalk
# CWD/tests/debugtalk.py
>>> locate_debugtalk_py("/path/to/CWD")
tests.debugtalk
Raises:
exceptions.FileNotFound: If failed to locate debugtalk.py module.
"""
if os.path.isfile(start_path):
start_dir_path = os.path.dirname(start_path)
elif os.path.isdir(start_path):
start_dir_path = start_path
else:
raise exceptions.FileNotFound("invalid path: {}".format(start_path))
module_path = os.path.join(start_dir_path, "debugtalk.py")
if os.path.isfile(module_path):
return "debugtalk"
if os.path.isabs(module_path):
module_path = module_path[len(os.getcwd())+1:]
# make compatible with former version
# TODO: remove this compatiblity
module_path = os.path.join(start_dir_path, "tests", "debugtalk.py")
if os.path.isfile(module_path):
return "tests.debugtalk"
module_name = module_path.replace("/", ".").rstrip(".py")
return module_name
return None
# current working directory
if os.path.abspath(start_dir_path) == os.getcwd():
raise exceptions.FileNotFound("debugtalk.py module not found: {}".format(start_path))
# locate recursive upward
return locate_debugtalk_py(os.path.dirname(start_dir_path))
def load_debugtalk_module(module_name=None):
def load_debugtalk_module(start_path=None):
""" load debugtalk.py module.
Args:
module_name (str, optional): module name for debugtalk.py. Defaults to None.
start_path (str, optional): start locating path, maybe file path or directory path.
Defaults to current working directory.
Returns:
dict: variables and functions mapping for debugtalk.py
@@ -216,34 +222,22 @@ def load_debugtalk_module(module_name=None):
"functions": {}
}
Examples:
# debugtalk.py
>>> load_debugtalk_module()
debugtalk
# tests/debugtalk.py
>>> load_debugtalk_module()
tests.debugtalk
Raises:
exceptions.ParamsError: If failed to import specified module.
"""
module_name = module_name or locate_debugtalk_py(os.getcwd())
if not module_name:
return {}
start_path = start_path or os.getcwd()
try:
imported_module = importlib.import_module(module_name)
except ImportError:
raise exceptions.ParamsError("module name error: {}".format(module_name))
module_name = locate_debugtalk_py(start_path)
except exceptions.FileNotFound:
return {}
return {
imported_module = importlib.import_module(module_name)
debugtalk_module = {
"variables": utils.filter_module(imported_module, "variable"),
"functions": utils.filter_module(imported_module, "function")
}
return debugtalk_module
###############################################################################
## suite loader

View File

@@ -144,23 +144,39 @@ class TestFileLoader(unittest.TestCase):
class TestModuleLoader(unittest.TestCase):
def test_locate_debugtalk_py(self):
self.assertEqual(loader.locate_debugtalk_py(os.getcwd()), "tests.debugtalk")
with self.assertRaises(exceptions.FileNotFound):
loader.locate_debugtalk_py(os.getcwd())
start_dir_path = os.path.join(os.getcwd(), "tests")
with self.assertRaises(exceptions.FileNotFound):
loader.locate_debugtalk_py("")
start_path = os.path.join(os.getcwd(), "tests")
self.assertEqual(
loader.locate_debugtalk_py(start_dir_path),
"debugtalk"
loader.locate_debugtalk_py(start_path),
"tests.debugtalk"
)
start_dir_path = os.path.join(os.getcwd(), "not_exist")
self.assertEqual(
loader.locate_debugtalk_py(start_dir_path),
None
loader.locate_debugtalk_py("tests/"),
"tests.debugtalk"
)
self.assertEqual(
loader.locate_debugtalk_py("tests"),
"tests.debugtalk"
)
self.assertEqual(
loader.locate_debugtalk_py("tests/base.py"),
"tests.debugtalk"
)
self.assertEqual(
loader.locate_debugtalk_py("tests/data/test.env"),
"tests.debugtalk"
)
def test_load_debugtalk_module(self):
imported_module_items = loader.load_debugtalk_module("tests.debugtalk")
print(imported_module_items)
imported_module_items = loader.load_debugtalk_module()
self.assertEqual(imported_module_items, {})
imported_module_items = loader.load_debugtalk_module("tests")
self.assertEqual(
imported_module_items["variables"]["SECRET_KEY"],
"DebugTalk"
@@ -171,9 +187,6 @@ class TestModuleLoader(unittest.TestCase):
self.assertTrue(is_status_code_200(200))
self.assertFalse(is_status_code_200(500))
with self.assertRaises(exceptions.ParamsError):
loader.load_debugtalk_module("debugtalk")
class TestSuiteLoader(unittest.TestCase):