diff --git a/joj3_config_generator/loader.py b/joj3_config_generator/loader.py index 6c8b877..cbb3ee3 100644 --- a/joj3_config_generator/loader.py +++ b/joj3_config_generator/loader.py @@ -1,6 +1,6 @@ from importlib import resources from pathlib import Path -from typing import Tuple +from typing import Tuple, Type, cast import inquirer import tomli @@ -11,13 +11,14 @@ from joj3_config_generator.models import answer, joj1, repo, task def load_joj3_task_toml_answers() -> answer.Answers: name = inquirer.text("What's the task name?", default="hw0") - language: answer.LanguageInterface = inquirer.list_input( - "What's the language?", choices=answer.LANGUAGES + language = inquirer.list_input( + "What's the language?", choices=[(cls.name, cls) for cls in answer.LANGUAGES] ) + language = cast(Type[answer.LanguageInterface], language) if inquirer.confirm("Load content from templates?", default=True): answers = inquirer.prompt(language.get_template_questions()) templates_dir = resources.files(f"joj3_config_generator.templates").joinpath( - language.__str__() + language.name ) template_file_path = answers["template_file"] template_file_content = Path(templates_dir / template_file_path).read_text() diff --git a/joj3_config_generator/models/answer.py b/joj3_config_generator/models/answer.py index 706db22..3be9fd7 100644 --- a/joj3_config_generator/models/answer.py +++ b/joj3_config_generator/models/answer.py @@ -1,16 +1,14 @@ from abc import ABC, abstractmethod from enum import Enum from importlib import resources -from typing import Any, ClassVar, Dict, List +from typing import Any, ClassVar, Dict, List, Type import inquirer from pydantic import BaseModel, ConfigDict class LanguageInterface(ABC): - @classmethod - @abstractmethod - def __str__(cls) -> str: ... + name: ClassVar[str] @abstractmethod class Stage(str, Enum): ... @@ -36,7 +34,7 @@ class LanguageInterface(ABC): @classmethod def get_template_questions(cls) -> List[Any]: templates_dir = resources.files(f"joj3_config_generator.templates").joinpath( - cls.__str__() + cls.name ) choices = [] for entry in templates_dir.iterdir(): @@ -52,9 +50,7 @@ class LanguageInterface(ABC): class Cpp(LanguageInterface): - @classmethod - def __str__(cls) -> str: - return "C++" + name = "C++" class Stage(str, Enum): COMPILATION = "Compilation" @@ -72,24 +68,23 @@ class Cpp(LanguageInterface): @classmethod def get_attribute_questions(cls) -> List[Any]: + attribute: Cpp.Attribute = cls.attribute return [ inquirer.Text( name="compile_command", message="Compile command", - default=cls.attribute.compile_command, + default=attribute.compile_command, ), inquirer.Text( name="run_command", message="Run command", - default=cls.attribute.run_command, + default=attribute.run_command, ), ] class Python(LanguageInterface): - @classmethod - def __str__(cls) -> str: - return "Python" + name = "Python" class Stage(str, Enum): RUN = "Run" @@ -102,19 +97,18 @@ class Python(LanguageInterface): @classmethod def get_attribute_questions(cls) -> List[Any]: + attribute: Python.Attribute = cls.attribute return [ inquirer.Text( name="run_command", message="Run command", - default=cls.attribute.run_command, + default=attribute.run_command, ), ] class Rust(LanguageInterface): - @classmethod - def __str__(cls) -> str: - return "Rust" + name = "Rust" class Stage(str, Enum): COMPILATION = "Compilation" @@ -129,19 +123,20 @@ class Rust(LanguageInterface): @classmethod def get_attribute_questions(cls) -> List[Any]: + attribute: Rust.Attribute = cls.attribute return [] -LANGUAGES = [ - Cpp(), - Python(), - Rust(), +LANGUAGES: List[Type[LanguageInterface]] = [ + Cpp, + Python, + Rust, ] class Answers(BaseModel): name: str - language: LanguageInterface + language: Type[LanguageInterface] template_file_content: str = "" model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/joj3_config_generator/transformers/answer.py b/joj3_config_generator/transformers/answer.py index b0a9917..cdfc29b 100644 --- a/joj3_config_generator/transformers/answer.py +++ b/joj3_config_generator/transformers/answer.py @@ -14,14 +14,14 @@ def get_task_conf_from_answers(answers: answer.Answers) -> task.Config: ) language = answers.language transformer_dict = get_transformer_dict() - transformer = transformer_dict[type(language)] + transformer = transformer_dict[language] stages = transformer(language) return task.Config(task=task.Task(name=answers.name), stages=stages) def get_transformer_dict() -> Dict[ - Type[Any], - Callable[[Any], List[task.Stage]], + Type[answer.LanguageInterface], + Callable[[Type[Any]], List[task.Stage]], ]: return { answer.Cpp: get_cpp_stages, @@ -31,7 +31,7 @@ def get_transformer_dict() -> Dict[ # TODO: implement -def get_cpp_stages(language: answer.Cpp) -> List[task.Stage]: +def get_cpp_stages(language: Type[answer.Cpp]) -> List[task.Stage]: stages = language.stages attribute: answer.Cpp.Attribute = language.attribute task_stages = [] @@ -47,14 +47,14 @@ def get_cpp_stages(language: answer.Cpp) -> List[task.Stage]: # TODO: implement -def get_python_stages(language: answer.Python) -> List[task.Stage]: +def get_python_stages(language: Type[answer.Python]) -> List[task.Stage]: stages = language.stages attribute: answer.Python.Attribute = language.attribute return [] # TODO: implement -def get_rust_stages(language: answer.Rust) -> List[task.Stage]: +def get_rust_stages(language: Type[answer.Rust]) -> List[task.Stage]: stages = language.stages attribute: answer.Rust.Attribute = language.attribute return [] diff --git a/tests/create/utils.py b/tests/create/utils.py index 45fe305..d4d3bba 100644 --- a/tests/create/utils.py +++ b/tests/create/utils.py @@ -13,7 +13,7 @@ def load_case(case_name: str) -> None: answers_json_path = root / case_name / "answers.json" task_toml_path = root / case_name / "task.toml" answers_dict = json.loads(answers_json_path.read_text()) - language = next(x for x in answer.LANGUAGES if str(x) == answers_dict["language"]) + language = next(x for x in answer.LANGUAGES if x.name == answers_dict["language"]) language.set_stages(answers_dict["stages"]) language.set_attribute(answers_dict["attribute"]) answers = answer.Answers(name=answers_dict["name"], language=language)