diff --git a/ate/exception.py b/ate/exception.py index 53b91138..92df5887 100644 --- a/ate/exception.py +++ b/ate/exception.py @@ -1,4 +1,8 @@ #coding: utf-8 +try: + FileNotFoundError = FileNotFoundError +except NameError: + FileNotFoundError = IOError class MyBaseError(BaseException): pass @@ -14,3 +18,6 @@ class ParseResponseError(MyBaseError): class ValidationError(MyBaseError): pass + +class FunctionNotFound(NameError): + pass diff --git a/ate/utils.py b/ate/utils.py index b795984e..4eb1eb50 100644 --- a/ate/utils.py +++ b/ate/utils.py @@ -2,6 +2,7 @@ import codecs import fnmatch import hashlib import hmac +import imp import importlib import json import os.path @@ -260,8 +261,42 @@ def get_imported_module(module_name): """ return importlib.import_module(module_name) +def get_imported_module_from_file(file_path): + """ import module from python file path and return imported module + """ + + if PYTHON_VERSION == 3: + imported_module = importlib.machinery.SourceFileLoader( + 'module_name', file_path).load_module() + else: + # Python 2.7 + imported_module = imp.load_source('module_name', file_path) + + return imported_module + def filter_module_functions(module): """ filter functions from import module """ module_functions_dict = dict(filter(is_function, vars(module).items())) return module_functions_dict + +def search_conf_function(start_path, func): + """ search expected function recursive upward + """ + dir_path = os.path.dirname(os.path.abspath(start_path)) + target_file = os.path.join(dir_path, "debugtalk.py") + + if os.path.isfile(target_file): + imported_module = get_imported_module_from_file(target_file) + functions_dict = filter_module_functions(imported_module) + if func in functions_dict: + return functions_dict[func] + else: + return search_conf_function(dir_path, func) + + if dir_path == start_path: + # system root path + err_msg = "{} not found in recursive upward path!".format(func) + raise exception.FunctionNotFound(err_msg) + + return search_conf_function(dir_path, func) diff --git a/tests/test_utils.py b/tests/test_utils.py index d033dbdb..697e75c9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -225,3 +225,41 @@ class TestUtils(ApiServerUnittest): updated_dict, {'a': 2, 'b': {'c': 33, 'd': 4, 'e': 5}, 'f': 6, 'g': 7} ) + + def test_get_imported_module(self): + imported_module = utils.get_imported_module("os") + self.assertIn("walk", dir(imported_module)) + + def test_filter_module_functions(self): + imported_module = utils.get_imported_module("ate.utils") + self.assertIn("PYTHON_VERSION", dir(imported_module)) + + functions_dict = utils.filter_module_functions(imported_module) + self.assertIn("filter_module_functions", functions_dict) + self.assertNotIn("PYTHON_VERSION", functions_dict) + + def test_get_imported_module_from_file(self): + imported_module = utils.get_imported_module_from_file("tests/data/debugtalk.py") + self.assertIn("gen_md5", dir(imported_module)) + + functions_dict = utils.filter_module_functions(imported_module) + self.assertIn("gen_md5", functions_dict) + self.assertNotIn("PYTHON_VERSION", functions_dict) + + with self.assertRaises(exception.FileNotFoundError): + utils.get_imported_module_from_file("tests/data/debugtalk2.py") + + def test_search_conf_function(self): + gen_md5 = utils.search_conf_function("tests/data/demo_binds.yml", "gen_md5") + self.assertTrue(utils.is_function(("gen_md5", gen_md5))) + self.assertEqual(gen_md5("abc"), "900150983cd24fb0d6963f7d28e17f72") + + gen_md5 = utils.search_conf_function("tests/data/subfolder/test.yml", "gen_md5") + self.assertTrue(utils.is_function(("_", gen_md5))) + self.assertEqual(gen_md5("abc"), "900150983cd24fb0d6963f7d28e17f72") + + with self.assertRaises(exception.FunctionNotFound): + utils.search_conf_function("tests/data/subfolder/test.yml", "func_not_exist") + + with self.assertRaises(exception.FunctionNotFound): + utils.search_conf_function("/user/local/bin", "gen_md5")