From 48d01045b53ecf74e72cfa2160ca87d9d1b8de40 Mon Sep 17 00:00:00 2001 From: Boming Zhang Date: Wed, 23 Oct 2024 01:39:36 -0400 Subject: [PATCH] refactor: more compact test codes --- tests/convert/test_convert_cases.py | 7 ++----- tests/convert/utils.py | 19 +++++++++++++------ tests/convert_joj1/test_convert_joj1_cases.py | 7 ++----- tests/convert_joj1/utils.py | 9 ++++++++- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/tests/convert/test_convert_cases.py b/tests/convert/test_convert_cases.py index d32279d..1a3392a 100644 --- a/tests/convert/test_convert_cases.py +++ b/tests/convert/test_convert_cases.py @@ -1,8 +1,5 @@ -from joj3_config_generator.convert import convert -from tests.convert.utils import read_convert_files +from tests.convert.utils import load_case def test_basic() -> None: - repo, task, expected_result = read_convert_files("basic") - result = convert(repo, task).model_dump(by_alias=True) - assert result == expected_result + load_case("basic") diff --git a/tests/convert/utils.py b/tests/convert/utils.py index 519fda2..5875138 100644 --- a/tests/convert/utils.py +++ b/tests/convert/utils.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Tuple import rtoml +from joj3_config_generator.convert import convert from joj3_config_generator.models import repo, task @@ -12,14 +13,20 @@ def read_convert_files( ) -> Tuple[repo.Config, task.Config, Dict[str, Any]]: root = os.path.dirname(os.path.realpath(__file__)) repo_toml_path = os.path.join(root, case_name, "repo.toml") + with open(repo_toml_path) as f: + repo_toml = f.read() task_toml_path = os.path.join(root, case_name, "task.toml") + with open(task_toml_path) as f: + task_toml = f.read() result_json_path = os.path.join(root, case_name, "task.json") - with open(repo_toml_path) as repo_file: - repo_toml = repo_file.read() - with open(task_toml_path) as task_file: - task_toml = task_file.read() - with open(result_json_path) as result_file: - result: Dict[str, Any] = json.load(result_file) + with open(result_json_path) as f: + result: Dict[str, Any] = json.load(f) repo_obj = rtoml.loads(repo_toml) task_obj = rtoml.loads(task_toml) return repo.Config(**repo_obj), task.Config(**task_obj), result + + +def load_case(case_name: str) -> None: + repo, task, expected_result = read_convert_files(case_name) + result = convert(repo, task).model_dump(mode="json", by_alias=True) + assert result == expected_result diff --git a/tests/convert_joj1/test_convert_joj1_cases.py b/tests/convert_joj1/test_convert_joj1_cases.py index 8c0452a..1307e39 100644 --- a/tests/convert_joj1/test_convert_joj1_cases.py +++ b/tests/convert_joj1/test_convert_joj1_cases.py @@ -1,11 +1,8 @@ import pytest -from joj3_config_generator.convert import convert_joj1 -from tests.convert_joj1.utils import read_convert_joj1_files +from tests.convert_joj1.utils import load_case @pytest.mark.xfail def test_basic() -> None: - joj1, expected_result = read_convert_joj1_files("basic") - result = convert_joj1(joj1).model_dump(by_alias=True) - assert result == expected_result + load_case("basic") diff --git a/tests/convert_joj1/utils.py b/tests/convert_joj1/utils.py index 10cc8d2..36f6108 100644 --- a/tests/convert_joj1/utils.py +++ b/tests/convert_joj1/utils.py @@ -4,17 +4,24 @@ from typing import Any, Dict, Tuple import rtoml import yaml +from joj3_config_generator.convert import convert_joj1 from joj3_config_generator.models import joj1 def read_convert_joj1_files(case_name: str) -> Tuple[joj1.Config, Dict[str, Any]]: root = os.path.dirname(os.path.realpath(__file__)) task_yaml_path = os.path.join(root, case_name, "task.yaml") - task_toml_path = os.path.join(root, case_name, "task.toml") with open(task_yaml_path) as f: task_yaml = f.read() + task_toml_path = os.path.join(root, case_name, "task.toml") with open(task_toml_path) as f: task_toml = f.read() joj1_obj = yaml.safe_load(task_yaml) task_obj = rtoml.loads(task_toml) return joj1.Config(**joj1_obj), task_obj + + +def load_case(case_name: str) -> None: + joj1, expected_result = read_convert_joj1_files(case_name) + result = convert_joj1(joj1).model_dump(by_alias=True) + assert result == expected_result