style: better types
All checks were successful
build / build (push) Successful in 2m2s

This commit is contained in:
张泊明518370910136 2025-03-19 09:19:58 -04:00
parent 19f5c7193f
commit 160f16ca2b
GPG Key ID: D47306D7062CDA9D
4 changed files with 29 additions and 33 deletions

View File

@ -1,6 +1,6 @@
from importlib import resources from importlib import resources
from pathlib import Path from pathlib import Path
from typing import Tuple from typing import Tuple, Type, cast
import inquirer import inquirer
import tomli import tomli
@ -11,13 +11,14 @@ from joj3_config_generator.models import answer, joj1, repo, task
def load_joj3_task_toml_answers() -> answer.Answers: def load_joj3_task_toml_answers() -> answer.Answers:
name = inquirer.text("What's the task name?", default="hw0") name = inquirer.text("What's the task name?", default="hw0")
language: answer.LanguageInterface = inquirer.list_input( language = inquirer.list_input(
"What's the language?", choices=answer.LANGUAGES "What's the language?", choices=[(cls.name, cls) for cls in answer.LANGUAGES]
) )
language = cast(Type[answer.LanguageInterface], language)
if inquirer.confirm("Load content from templates?", default=True): if inquirer.confirm("Load content from templates?", default=True):
answers = inquirer.prompt(language.get_template_questions()) answers = inquirer.prompt(language.get_template_questions())
templates_dir = resources.files(f"joj3_config_generator.templates").joinpath( templates_dir = resources.files(f"joj3_config_generator.templates").joinpath(
language.__str__() language.name
) )
template_file_path = answers["template_file"] template_file_path = answers["template_file"]
template_file_content = Path(templates_dir / template_file_path).read_text() template_file_content = Path(templates_dir / template_file_path).read_text()

View File

@ -1,16 +1,14 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from importlib import resources from importlib import resources
from typing import Any, ClassVar, Dict, List from typing import Any, ClassVar, Dict, List, Type
import inquirer import inquirer
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
class LanguageInterface(ABC): class LanguageInterface(ABC):
@classmethod name: ClassVar[str]
@abstractmethod
def __str__(cls) -> str: ...
@abstractmethod @abstractmethod
class Stage(str, Enum): ... class Stage(str, Enum): ...
@ -36,7 +34,7 @@ class LanguageInterface(ABC):
@classmethod @classmethod
def get_template_questions(cls) -> List[Any]: def get_template_questions(cls) -> List[Any]:
templates_dir = resources.files(f"joj3_config_generator.templates").joinpath( templates_dir = resources.files(f"joj3_config_generator.templates").joinpath(
cls.__str__() cls.name
) )
choices = [] choices = []
for entry in templates_dir.iterdir(): for entry in templates_dir.iterdir():
@ -52,9 +50,7 @@ class LanguageInterface(ABC):
class Cpp(LanguageInterface): class Cpp(LanguageInterface):
@classmethod name = "C++"
def __str__(cls) -> str:
return "C++"
class Stage(str, Enum): class Stage(str, Enum):
COMPILATION = "Compilation" COMPILATION = "Compilation"
@ -72,24 +68,23 @@ class Cpp(LanguageInterface):
@classmethod @classmethod
def get_attribute_questions(cls) -> List[Any]: def get_attribute_questions(cls) -> List[Any]:
attribute: Cpp.Attribute = cls.attribute
return [ return [
inquirer.Text( inquirer.Text(
name="compile_command", name="compile_command",
message="Compile command", message="Compile command",
default=cls.attribute.compile_command, default=attribute.compile_command,
), ),
inquirer.Text( inquirer.Text(
name="run_command", name="run_command",
message="Run command", message="Run command",
default=cls.attribute.run_command, default=attribute.run_command,
), ),
] ]
class Python(LanguageInterface): class Python(LanguageInterface):
@classmethod name = "Python"
def __str__(cls) -> str:
return "Python"
class Stage(str, Enum): class Stage(str, Enum):
RUN = "Run" RUN = "Run"
@ -102,19 +97,18 @@ class Python(LanguageInterface):
@classmethod @classmethod
def get_attribute_questions(cls) -> List[Any]: def get_attribute_questions(cls) -> List[Any]:
attribute: Python.Attribute = cls.attribute
return [ return [
inquirer.Text( inquirer.Text(
name="run_command", name="run_command",
message="Run command", message="Run command",
default=cls.attribute.run_command, default=attribute.run_command,
), ),
] ]
class Rust(LanguageInterface): class Rust(LanguageInterface):
@classmethod name = "Rust"
def __str__(cls) -> str:
return "Rust"
class Stage(str, Enum): class Stage(str, Enum):
COMPILATION = "Compilation" COMPILATION = "Compilation"
@ -129,19 +123,20 @@ class Rust(LanguageInterface):
@classmethod @classmethod
def get_attribute_questions(cls) -> List[Any]: def get_attribute_questions(cls) -> List[Any]:
attribute: Rust.Attribute = cls.attribute
return [] return []
LANGUAGES = [ LANGUAGES: List[Type[LanguageInterface]] = [
Cpp(), Cpp,
Python(), Python,
Rust(), Rust,
] ]
class Answers(BaseModel): class Answers(BaseModel):
name: str name: str
language: LanguageInterface language: Type[LanguageInterface]
template_file_content: str = "" template_file_content: str = ""
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@ -14,14 +14,14 @@ def get_task_conf_from_answers(answers: answer.Answers) -> task.Config:
) )
language = answers.language language = answers.language
transformer_dict = get_transformer_dict() transformer_dict = get_transformer_dict()
transformer = transformer_dict[type(language)] transformer = transformer_dict[language]
stages = transformer(language) stages = transformer(language)
return task.Config(task=task.Task(name=answers.name), stages=stages) return task.Config(task=task.Task(name=answers.name), stages=stages)
def get_transformer_dict() -> Dict[ def get_transformer_dict() -> Dict[
Type[Any], Type[answer.LanguageInterface],
Callable[[Any], List[task.Stage]], Callable[[Type[Any]], List[task.Stage]],
]: ]:
return { return {
answer.Cpp: get_cpp_stages, answer.Cpp: get_cpp_stages,
@ -31,7 +31,7 @@ def get_transformer_dict() -> Dict[
# TODO: implement # TODO: implement
def get_cpp_stages(language: answer.Cpp) -> List[task.Stage]: def get_cpp_stages(language: Type[answer.Cpp]) -> List[task.Stage]:
stages = language.stages stages = language.stages
attribute: answer.Cpp.Attribute = language.attribute attribute: answer.Cpp.Attribute = language.attribute
task_stages = [] task_stages = []
@ -47,14 +47,14 @@ def get_cpp_stages(language: answer.Cpp) -> List[task.Stage]:
# TODO: implement # TODO: implement
def get_python_stages(language: answer.Python) -> List[task.Stage]: def get_python_stages(language: Type[answer.Python]) -> List[task.Stage]:
stages = language.stages stages = language.stages
attribute: answer.Python.Attribute = language.attribute attribute: answer.Python.Attribute = language.attribute
return [] return []
# TODO: implement # TODO: implement
def get_rust_stages(language: answer.Rust) -> List[task.Stage]: def get_rust_stages(language: Type[answer.Rust]) -> List[task.Stage]:
stages = language.stages stages = language.stages
attribute: answer.Rust.Attribute = language.attribute attribute: answer.Rust.Attribute = language.attribute
return [] return []

View File

@ -13,7 +13,7 @@ def load_case(case_name: str) -> None:
answers_json_path = root / case_name / "answers.json" answers_json_path = root / case_name / "answers.json"
task_toml_path = root / case_name / "task.toml" task_toml_path = root / case_name / "task.toml"
answers_dict = json.loads(answers_json_path.read_text()) answers_dict = json.loads(answers_json_path.read_text())
language = next(x for x in answer.LANGUAGES if str(x) == answers_dict["language"]) language = next(x for x in answer.LANGUAGES if x.name == answers_dict["language"])
language.set_stages(answers_dict["stages"]) language.set_stages(answers_dict["stages"])
language.set_attribute(answers_dict["attribute"]) language.set_attribute(answers_dict["attribute"])
answers = answer.Answers(name=answers_dict["name"], language=language) answers = answer.Answers(name=answers_dict["name"], language=language)