From 83ee357adeb918b6780316f0c4c0c02209a65848 Mon Sep 17 00:00:00 2001 From: debugtalk Date: Fri, 1 Sep 2017 15:18:34 +0800 Subject: [PATCH] search and filter module variables --- ate/context.py | 2 +- ate/exception.py | 3 +++ ate/testcase.py | 20 ++++++++++++------- ate/utils.py | 47 +++++++++++++++++++++++++++++++++------------ tests/test_utils.py | 29 +++++++++++++++++++++------- 5 files changed, 74 insertions(+), 27 deletions(-) diff --git a/ate/context.py b/ate/context.py index fd5394b8..e9b86f73 100644 --- a/ate/context.py +++ b/ate/context.py @@ -70,7 +70,7 @@ class Context(object): sys.path.insert(0, os.getcwd()) for module_name in modules: imported_module = utils.get_imported_module(module_name) - imported_functions_dict = utils.filter_module_functions(imported_module) + imported_functions_dict = utils.filter_module(imported_module, "function") self.__update_context_functions_config(level, imported_functions_dict) def bind_variables(self, variable_binds, level="testcase"): diff --git a/ate/exception.py b/ate/exception.py index 92df5887..fa1cc24f 100644 --- a/ate/exception.py +++ b/ate/exception.py @@ -21,3 +21,6 @@ class ValidationError(MyBaseError): class FunctionNotFound(NameError): pass + +class VariableNotFound(NameError): + pass diff --git a/ate/testcase.py b/ate/testcase.py index 2c48e8b3..2ed5060a 100644 --- a/ate/testcase.py +++ b/ate/testcase.py @@ -150,17 +150,23 @@ class TestcaseParser(object): """ self.functions_binds = functions_binds - def get_bind_fuctions(self, func_name): - func = self.functions_binds.get(func_name) - if func: - return func + def get_bind_item(self, item_type, item_name): + if item_type == "function": + item = self.functions_binds.get(item_name) + elif item_type == "variable": + item = self.variables_binds.get(item_name) + else: + raise exception.ParamsError("bind item should only be function or variable.") + + if item: + return item try: assert self.file_path is not None - return utils.search_conf_function(self.file_path, func_name) + return utils.search_conf_item(self.file_path, item_type, item_name) except (AssertionError, exception.FunctionNotFound): raise exception.ParamsError( - "%s is not defined in bind functions!" % func_name) + "{} is not defined in bind {}s!".format(item_name, item_type)) def eval_content_functions(self, content): functions_list = extract_functions(content) @@ -168,7 +174,7 @@ class TestcaseParser(object): function_meta = parse_function(func_content) func_name = function_meta['func_name'] - func = self.get_bind_fuctions(func_name) + func = self.get_bind_item("function", func_name) args = function_meta.get('args', []) kwargs = function_meta.get('kwargs', {}) diff --git a/ate/utils.py b/ate/utils.py index 4eb1eb50..2a5a339c 100644 --- a/ate/utils.py +++ b/ate/utils.py @@ -256,6 +256,18 @@ def is_function(tup): name, item = tup return isinstance(item, types.FunctionType) +def is_variable(tup): + """ Takes (name, object) tuple, returns True if it is a variable. + """ + name, item = tup + if callable(item): + return False + + if name.startswith("__"): + return False + + return True + def get_imported_module(module_name): """ import module and return imported module """ @@ -274,29 +286,40 @@ def get_imported_module_from_file(file_path): return imported_module -def filter_module_functions(module): - """ filter functions from import module +def filter_module(module, filter_type): + """ filter functions or variables from import module + @params + module: imported module + filter_type: "function" or "variable" """ - module_functions_dict = dict(filter(is_function, vars(module).items())) + filter_type = is_function if filter_type == "function" else is_variable + module_functions_dict = dict(filter(filter_type, vars(module).items())) return module_functions_dict -def search_conf_function(start_path, func): - """ search expected function recursive upward +def search_conf_item(start_path, item_type, item_name): + """ search expected function or variable recursive upward + @param + start_path: search start path + item_type: "function" or "variable" + item_name: function name or variable name """ dir_path = os.path.dirname(os.path.abspath(start_path)) target_file = os.path.join(dir_path, "debugtalk.py") if os.path.isfile(target_file): imported_module = get_imported_module_from_file(target_file) - functions_dict = filter_module_functions(imported_module) - if func in functions_dict: - return functions_dict[func] + functions_dict = filter_module(imported_module, item_type) + if item_name in functions_dict: + return functions_dict[item_name] else: - return search_conf_function(dir_path, func) + return search_conf_item(dir_path, item_type, item_name) if dir_path == start_path: # system root path - err_msg = "{} not found in recursive upward path!".format(func) - raise exception.FunctionNotFound(err_msg) + err_msg = "{} not found in recursive upward path!".format(item_name) + if item_type == "function": + raise exception.FunctionNotFound(err_msg) + else: + raise exception.VariableNotFound(err_msg) - return search_conf_function(dir_path, func) + return search_conf_item(dir_path, item_type, item_name) diff --git a/tests/test_utils.py b/tests/test_utils.py index 697e75c9..655338d1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -234,15 +234,15 @@ class TestUtils(ApiServerUnittest): imported_module = utils.get_imported_module("ate.utils") self.assertIn("PYTHON_VERSION", dir(imported_module)) - functions_dict = utils.filter_module_functions(imported_module) - self.assertIn("filter_module_functions", functions_dict) + functions_dict = utils.filter_module(imported_module, "function") + self.assertIn("filter_module", functions_dict) self.assertNotIn("PYTHON_VERSION", functions_dict) def test_get_imported_module_from_file(self): imported_module = utils.get_imported_module_from_file("tests/data/debugtalk.py") self.assertIn("gen_md5", dir(imported_module)) - functions_dict = utils.filter_module_functions(imported_module) + functions_dict = utils.filter_module(imported_module, "function") self.assertIn("gen_md5", functions_dict) self.assertNotIn("PYTHON_VERSION", functions_dict) @@ -250,16 +250,31 @@ class TestUtils(ApiServerUnittest): utils.get_imported_module_from_file("tests/data/debugtalk2.py") def test_search_conf_function(self): - gen_md5 = utils.search_conf_function("tests/data/demo_binds.yml", "gen_md5") + gen_md5 = utils.search_conf_item("tests/data/demo_binds.yml", "function", "gen_md5") self.assertTrue(utils.is_function(("gen_md5", gen_md5))) self.assertEqual(gen_md5("abc"), "900150983cd24fb0d6963f7d28e17f72") - gen_md5 = utils.search_conf_function("tests/data/subfolder/test.yml", "gen_md5") + gen_md5 = utils.search_conf_item("tests/data/subfolder/test.yml", "function", "gen_md5") self.assertTrue(utils.is_function(("_", gen_md5))) self.assertEqual(gen_md5("abc"), "900150983cd24fb0d6963f7d28e17f72") with self.assertRaises(exception.FunctionNotFound): - utils.search_conf_function("tests/data/subfolder/test.yml", "func_not_exist") + utils.search_conf_item("tests/data/subfolder/test.yml", "function", "func_not_exist") with self.assertRaises(exception.FunctionNotFound): - utils.search_conf_function("/user/local/bin", "gen_md5") + utils.search_conf_item("/user/local/bin", "function", "gen_md5") + + def test_search_conf_variable(self): + SECRET_KEY = utils.search_conf_item("tests/data/demo_binds.yml", "variable", "SECRET_KEY") + self.assertTrue(utils.is_variable(("SECRET_KEY", SECRET_KEY))) + self.assertEqual(SECRET_KEY, "DebugTalk") + + SECRET_KEY = utils.search_conf_item("tests/data/subfolder/test.yml", "variable", "SECRET_KEY") + self.assertTrue(utils.is_variable(("SECRET_KEY", SECRET_KEY))) + self.assertEqual(SECRET_KEY, "DebugTalk") + + with self.assertRaises(exception.VariableNotFound): + utils.search_conf_item("tests/data/subfolder/test.yml", "variable", "variable_not_exist") + + with self.assertRaises(exception.VariableNotFound): + utils.search_conf_item("/user/local/bin", "variable", "SECRET_KEY")