diff --git a/ate/utils.py b/ate/utils.py index 9a82b068..63746621 100644 --- a/ate/utils.py +++ b/ate/utils.py @@ -65,6 +65,16 @@ def load_folder_files(folder_path, recursive=True): folder_path: specified folder path to load recursive: if True, will load files recursively """ + if isinstance(folder_path, (list, set)): + files = [] + for path in set(folder_path): + files.extend(load_folder_files(path, recursive)) + + return files + + if not os.path.exists(folder_path): + return [] + file_list = [] for dirpath, dirnames, filenames in os.walk(folder_path): diff --git a/tests/test_utils.py b/tests/test_utils.py index 9acd913e..80d42467 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -45,9 +45,21 @@ class TestUtils(ApiServerUnittest): self.assertIn(file2, files) self.assertNotIn(file1, files) - files = utils.load_folder_files(folder) + files_1 = utils.load_folder_files(folder) api_file = os.path.join(os.getcwd(), 'tests', 'api', 'demo.yml') - self.assertEqual(files[0], api_file) + self.assertEqual(files_1[0], api_file) + + folder_list = [folder, folder] + files_2 = utils.load_folder_files(folder) + api_file = os.path.join(os.getcwd(), 'tests', 'api', 'demo.yml') + self.assertEqual(files_2[0], api_file) + self.assertEqual(len(files_1), len(files_2)) + + files = utils.load_folder_files("not_existed_foulder", recursive=False) + self.assertEqual([], files) + + files = utils.load_folder_files(file2, recursive=False) + self.assertEqual([], files) def test_query_json(self): json_content = {