JOJ3-config-generator/joj3_config_generator/loader.py
张泊明518370910136 f556801884
All checks were successful
build / build (push) Successful in 3m8s
feat: check unnecessary fields (#12)
2025-06-01 21:43:37 -04:00

166 lines
6.9 KiB
Python

import json
from pathlib import Path
from typing import Any, Dict, Tuple, Type, cast
import inquirer
import tomli
import yaml
from pydantic import BaseModel
from joj3_config_generator.models import answer, joj1, repo, task
from joj3_config_generator.models.common import Memory, Time
from joj3_config_generator.utils.logger import logger
def load_joj3_task_toml_answers() -> answer.Answers:
    """Interactively collect the answers used to generate a task TOML.

    The user is asked, in order, for the task name, the language (one of
    ``answer.LANGUAGES``) and whether the content should be loaded from a
    template.

    * Template route: the language's template questions are prompted and the
      resulting ``template_file_content`` is stored in the returned answers.
    * Manual route: the user picks stages (all of ``language.Stage`` are
      pre-selected), then answers the language's attribute questions; both
      results are stored on the language via ``set_stages`` and
      ``set_attribute``.

    Returns:
        An ``answer.Answers`` carrying the task name, the chosen language
        and, on the template route only, the template file content.
    """
    task_name = inquirer.text("What's the task name?", default="hw0")
    language_choices = [(lang_cls.name, lang_cls) for lang_cls in answer.LANGUAGES]
    picked_language = inquirer.list_input(
        "What's the language?", choices=language_choices
    )
    # list_input is untyped; narrow the selection for the type checker.
    language = cast(Type[answer.LanguageInterface], picked_language)

    use_template = inquirer.confirm("Load content from templates?", default=True)
    if not use_template:
        # Manual route: configure stages and attributes on the language.
        stage_values = [member.value for member in language.Stage]
        selected_stages = inquirer.checkbox(
            "What's the stages?",
            choices=stage_values,
            default=list(stage_values),
        )
        language.set_stages(selected_stages)
        attribute_answers = inquirer.prompt(language.get_attribute_questions())
        language.set_attribute(attribute_answers)
        return answer.Answers(name=task_name, language=language)

    # Template route: only the template content is needed.
    template_answers = inquirer.prompt(language.get_template_questions())
    content: str = template_answers["template_file_content"]
    return answer.Answers(
        name=task_name,
        language=language,
        template_file_content=content,
    )
def load_joj1_yaml(yaml_path: Path) -> joj1.Config:
    """Load a JOJ1 YAML configuration file into a ``joj1.Config``.

    Args:
        yaml_path: Location of the YAML file to read.

    Returns:
        The config built from the document's top-level mapping, whose keys
        are passed to ``joj1.Config`` as keyword arguments.

    Raises:
        TypeError: If the document's top level is not a mapping (for
            example when the file is empty).
    """
    # Hand PyYAML the raw bytes so it determines the encoding itself (UTF-8,
    # or UTF-16 when a BOM is present) instead of depending on the platform's
    # locale encoding, which is what a bare ``read_text()`` would use.
    joj1_obj = yaml.safe_load(yaml_path.read_bytes())
    if not isinstance(joj1_obj, dict):
        # Unpacking a non-mapping with ** already raised TypeError; keep the
        # exception type but say which file is at fault and why.
        raise TypeError(
            f"{yaml_path}: expected a mapping at the top level of the YAML "
            f"document, got {type(joj1_obj).__name__}"
        )
    return joj1.Config(**joj1_obj)
# Sentinel meaning "the model declares no default value for this field".
_NO_DEFAULT = object()


def _format_value_for_toml_warning(value: Any) -> str:
    """Render ``value`` roughly the way it would be written in a TOML file."""
    if isinstance(value, (str, Path)):
        escaped_value = str(value).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped_value}"'
    # bool has to be tested before int because bool is a subclass of int.
    if isinstance(value, bool):
        return str(value).lower()
    if isinstance(value, (int, float)):
        return str(value)
    if isinstance(value, list):
        formatted_elements = [_format_value_for_toml_warning(item) for item in value]
        return f"[{', '.join(formatted_elements)}]"
    if isinstance(value, dict):
        return json.dumps(value, separators=(",", ":"))
    if value is None:
        return "None"
    return repr(value)


def _check_unnecessary_fields(
    pydantic_model_type: Type[BaseModel],
    input_dict: Dict[str, Any],
    file_path: Path,
    current_path: str = "",
) -> None:
    """Warn about entries of ``input_dict`` that merely restate a model default.

    Args:
        pydantic_model_type: Model class whose fields ``input_dict`` populates.
        input_dict: Raw (already parsed) TOML table for that model.
        file_path: File the table came from; only used in the warning text.
        current_path: Dotted path of the table inside the file, used as the
            prefix of the field path in warnings. Empty for the top level.
    """
    # model_construct() skips validation and fills in declared defaults only;
    # fields without a default stay unset on the instance.
    default_instance = pydantic_model_type.model_construct()
    for field_name, field_info in pydantic_model_type.model_fields.items():
        # The TOML file may use either the field name or its alias; the alias
        # wins when both are present.
        toml_field_name = field_name
        if field_info.alias in input_dict:
            toml_field_name = field_info.alias
        if toml_field_name not in input_dict:
            continue

        full_field_path = (
            f"{current_path}.{field_name}" if current_path else field_name
        )
        toml_value = input_dict[toml_field_name]
        annotation = field_info.annotation

        # Handle List[Pydantic.BaseModel]: recurse into every table of the list.
        type_args = getattr(annotation, "__args__", ())
        if (
            getattr(annotation, "__origin__", None) is list
            and len(type_args) == 1
            and isinstance(type_args[0], type)
            and issubclass(type_args[0], BaseModel)
        ):
            if isinstance(toml_value, list):
                for i, toml_item in enumerate(toml_value):
                    if isinstance(toml_item, dict):
                        _check_unnecessary_fields(
                            type_args[0],
                            toml_item,
                            file_path,
                            f"{full_field_path}[{i}]",
                        )
            continue

        # Handle directly nested Pydantic models (non-list).
        if isinstance(annotation, type) and issubclass(annotation, BaseModel):
            if isinstance(toml_value, dict):
                _check_unnecessary_fields(
                    annotation, toml_value, file_path, full_field_path
                )
            continue

        # A required field has no default, so it can never be "unnecessary".
        # Looking it up unguarded would raise AttributeError because
        # model_construct() never set it.
        default_value = getattr(default_instance, field_name, _NO_DEFAULT)
        if default_value is _NO_DEFAULT:
            continue

        # Compare in the default's own type where the TOML representation
        # differs from it (Path, Time, Memory); otherwise compare directly.
        # An empty table is never reported.
        should_warn = (
            (
                isinstance(toml_value, str)
                and isinstance(default_value, Path)
                and Path(toml_value) == default_value
            )
            or (isinstance(default_value, Time) and Time(toml_value) == default_value)
            or (
                isinstance(default_value, Memory)
                and Memory(toml_value) == default_value
            )
            or (
                isinstance(toml_value, list)
                and isinstance(default_value, list)
                and toml_value == default_value
            )
            or (toml_value == default_value and toml_value != {})
        )
        if should_warn:
            logger.warning(
                f"In file {file_path}, unnecessary field "
                f"`{full_field_path} = {_format_value_for_toml_warning(toml_value)}`"
                " can be removed as it matches the default value"
            )


def load_joj3_toml(
    root_path: Path, repo_toml_path: Path, task_toml_path: Path
) -> Tuple[repo.Config, task.Config]:
    """Load the repo-level and task-level JOJ3 TOML configuration files.

    Both files are parsed and turned into their config models. Each config
    then gets ``root`` set to ``root_path`` and ``path`` set to the TOML
    file's location relative to ``root_path``. Finally every field in either
    file that only repeats the model's default value is reported through a
    logger warning.

    Args:
        root_path: Root directory that both TOML paths are made relative to.
        repo_toml_path: Path of the repository configuration TOML.
        task_toml_path: Path of the task configuration TOML.

    Returns:
        A ``(repo_conf, task_conf)`` tuple.

    Raises:
        ValueError: If a TOML path is not located under ``root_path``
            (raised by ``Path.relative_to``).
    """
    # TOML documents are UTF-8 by specification, so decode explicitly rather
    # than with the platform's locale encoding.
    repo_obj = tomli.loads(repo_toml_path.read_text(encoding="utf-8"))
    task_obj = tomli.loads(task_toml_path.read_text(encoding="utf-8"))
    repo_conf = repo.Config(**repo_obj)
    repo_conf.root = root_path
    repo_conf.path = repo_toml_path.relative_to(root_path)
    task_conf = task.Config(**task_obj)
    task_conf.root = root_path
    task_conf.path = task_toml_path.relative_to(root_path)
    # Run the checks only after both configs validated, so the raw values are
    # known to be well-formed when they are converted for comparison.
    _check_unnecessary_fields(repo.Config, repo_obj, repo_toml_path)
    _check_unnecessary_fields(task.Config, task_obj, task_toml_path)
    return repo_conf, task_conf