170 lines
7.1 KiB
Python
170 lines
7.1 KiB
Python
import json
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Tuple, Type, cast
|
|
|
|
import inquirer
|
|
import tomli
|
|
import yaml
|
|
from pydantic import BaseModel
|
|
|
|
from joj3_config_generator.models import answer, joj1, repo, task
|
|
from joj3_config_generator.models.common import Memory, Time
|
|
from joj3_config_generator.utils.logger import logger
|
|
|
|
|
|
def load_joj3_task_toml_answers() -> answer.Answers:
    """Interactively gather the answers needed to generate a task TOML.

    The user is asked, in order, for the task name, the language (one of
    ``answer.LANGUAGES``) and whether to start from a template.

    * Template route: the language's template questions are prompted and the
      ``"template_file_content"`` answer is forwarded to ``answer.Answers``.
      When the first template question offers no choices, a warning is logged
      and the answers are returned without template content.
    * Manual route: the user picks stages (all of ``language.Stage`` are
      pre-selected) and answers the language's attribute questions; both are
      stored on the selected language class via ``set_stages`` and
      ``set_attribute``.

    Returns:
        The ``answer.Answers`` built from the collected input.
    """
    name = inquirer.text("What's the task name?", default="hw0")
    language_choices = [(cls.name, cls) for cls in answer.LANGUAGES]
    selected = inquirer.list_input("What's the language?", choices=language_choices)
    # The prompt library is untyped; tell the type checker what was selected.
    language = cast(Type[answer.LanguageInterface], selected)

    if not inquirer.confirm("Load content from templates?", default=True):
        # Manual route: configure stages and attributes by hand.
        stage_values = [member.value for member in language.Stage]
        stages = inquirer.checkbox(
            "What's the stages?",
            choices=stage_values,
            # Separate list object, as choices and default were built
            # independently before.
            default=list(stage_values),
        )
        language.set_stages(stages)
        attribute_questions = language.get_attribute_questions()
        attribute = inquirer.prompt(attribute_questions)
        language.set_attribute(attribute)
        return answer.Answers(name=name, language=language)

    # Template route.
    questions = language.get_template_questions()
    if not questions[0].choices:
        logger.warning("No template files found for the selected language. ")
        return answer.Answers(name=name, language=language)

    template_answers = inquirer.prompt(questions)
    template_file_content: str = template_answers["template_file_content"]
    return answer.Answers(
        name=name,
        language=language,
        template_file_content=template_file_content,
    )
|
|
|
|
|
|
def load_joj1_yaml(yaml_path: Path) -> joj1.Config:
    """Load a JOJ1 YAML configuration file into a ``joj1.Config``.

    The file content is handed to the YAML parser as raw bytes so that the
    parser performs its own encoding detection (BOM, otherwise UTF-8) instead
    of the result depending on the platform's default text encoding.

    Args:
        yaml_path: Path of the JOJ1 YAML file to read.

    Returns:
        A ``joj1.Config`` constructed from the document's top-level mapping,
        whose keys are passed as keyword arguments.
    """
    # Bytes, not ``read_text()``: the latter decodes with the locale encoding.
    joj1_obj = yaml.safe_load(yaml_path.read_bytes())
    return joj1.Config(**joj1_obj)
|
|
|
|
|
|
def load_joj3_toml(
    root_path: Path, repo_toml_path: Path, task_toml_path: Path
) -> Tuple[repo.Config, task.Config]:
    """Load the repo-level and task-level JOJ3 TOML files.

    Both files are parsed, turned into ``repo.Config`` / ``task.Config``, and
    tagged with ``root`` (``root_path``) and ``path`` (the file's path relative
    to ``root_path``). Afterwards the raw TOML data is scanned for keys whose
    value merely repeats the model default; each such key is reported with a
    warning through ``logger``. The scan never modifies the configs.

    Args:
        root_path: Directory that both TOML paths are made relative to.
        repo_toml_path: Path of the repo TOML file.
        task_toml_path: Path of the task TOML file.

    Returns:
        ``(repo_conf, task_conf)``.

    Raises:
        ValueError: If a TOML path is not located under ``root_path``
            (raised by ``Path.relative_to``).

    Errors raised while reading, parsing or validating the files are not
    caught here and propagate to the caller.
    """
    # Stand-in for "the model provides no default for this field".
    missing = object()

    def is_model_class(candidate: Any) -> bool:
        """Return True when *candidate* is a class derived from BaseModel."""
        # ``issubclass`` raises TypeError when its first argument is not a
        # real class. NOTE(review): some annotation objects (believed to
        # include builtin generics such as ``list[str]`` on older Python
        # versions) can still pass the ``isinstance(..., type)`` test —
        # treat those as "not a model" instead of crashing.
        try:
            return isinstance(candidate, type) and issubclass(candidate, BaseModel)
        except TypeError:
            return False

    def format_value_for_toml_warning(value: Any) -> str:
        """Render *value* roughly the way it would be written in TOML."""
        if isinstance(value, str):
            escaped_value = value.replace("\\", "\\\\").replace('"', '\\"')
            return f'"{escaped_value}"'
        # bool must be tested before int: ``True`` is also an ``int``.
        if isinstance(value, bool):
            return str(value).lower()
        if isinstance(value, (int, float)):
            return str(value)
        if isinstance(value, Path):
            escaped_value = str(value).replace("\\", "\\\\").replace('"', '\\"')
            return f'"{escaped_value}"'
        if isinstance(value, list):
            formatted_elements = [
                format_value_for_toml_warning(item) for item in value
            ]
            return f"[{', '.join(formatted_elements)}]"
        if isinstance(value, dict):
            return json.dumps(value, separators=(",", ":"))
        if value is None:
            return "None"
        return repr(value)

    def check_unnecessary_fields(
        pydantic_model_type: Type[BaseModel],
        input_dict: Dict[str, Any],
        file_path: Path,
        current_path: str = "",
    ) -> None:
        """Warn about keys in *input_dict* that equal the model's default.

        Args:
            pydantic_model_type: Model class describing *input_dict*.
            input_dict: Raw (unvalidated) TOML table.
            file_path: File the table came from; only used in the message.
            current_path: Dotted path of *input_dict* inside the file, used
                as prefix for the reported field path.
        """
        # Instance holding only default values; no validation is run.
        default_instance = pydantic_model_type.model_construct()
        for field_name, field_info in pydantic_model_type.model_fields.items():
            # NOTE(review): the reported path uses the Python field name even
            # when the TOML key was matched through the alias.
            full_field_path = (
                f"{current_path}.{field_name}" if current_path else field_name
            )
            # The TOML key may be either the field name or its alias.
            toml_field_name = field_name
            if field_info.alias in input_dict:
                toml_field_name = field_info.alias
            if toml_field_name not in input_dict:
                continue
            toml_value = input_dict[toml_field_name]
            annotation = field_info.annotation

            # List[SomeModel]: recurse into every table of the array.
            type_args = getattr(annotation, "__args__", ())
            if (
                annotation is not None
                and getattr(annotation, "__origin__", None) is list
                and len(type_args) == 1
                and is_model_class(type_args[0])
            ):
                nested_model_type = type_args[0]
                if isinstance(toml_value, list):
                    for i, toml_item in enumerate(toml_value):
                        if isinstance(toml_item, dict):
                            check_unnecessary_fields(
                                nested_model_type,
                                toml_item,
                                file_path,
                                f"{full_field_path}[{i}]",
                            )
                continue

            # Directly nested model: recurse into the sub-table.
            if is_model_class(annotation):
                if isinstance(toml_value, dict):
                    check_unnecessary_fields(
                        annotation,
                        toml_value,
                        file_path,
                        full_field_path,
                    )
                continue

            # A field without a default is not set on ``default_instance``;
            # the sentinel compares unequal to everything, so such a field is
            # never reported instead of raising AttributeError.
            default_value = getattr(default_instance, field_name, missing)

            should_warn = False
            # Path defaults: compare as paths, the TOML side is a string.
            if (
                isinstance(toml_value, str)
                and isinstance(default_value, Path)
                and Path(toml_value) == default_value
            ):
                should_warn = True
            # Time defaults: convert the TOML value before comparing.
            elif isinstance(default_value, Time) and Time(toml_value) == default_value:
                should_warn = True
            # Memory defaults: convert the TOML value before comparing.
            elif (
                isinstance(default_value, Memory)
                and Memory(toml_value) == default_value
            ):
                should_warn = True
            # Lists of non-model items (e.g. List[str], List[int]).
            elif (
                isinstance(toml_value, list)
                and isinstance(default_value, list)
                and toml_value == default_value
            ):
                should_warn = True
            # Remaining basic types; an empty table is never reported.
            elif toml_value == default_value and toml_value != {}:
                should_warn = True

            if should_warn:
                logger.warning(
                    f"In file {file_path}, unnecessary field "
                    f"`{full_field_path} = {format_value_for_toml_warning(toml_value)}`"
                    " can be removed as it matches the default value"
                )

    # TOML documents are UTF-8 by specification; do not rely on the locale.
    repo_obj = tomli.loads(repo_toml_path.read_text(encoding="utf-8"))
    task_obj = tomli.loads(task_toml_path.read_text(encoding="utf-8"))

    repo_conf = repo.Config(**repo_obj)
    repo_conf.root = root_path
    repo_conf.path = repo_toml_path.relative_to(root_path)

    task_conf = task.Config(**task_obj)
    task_conf.root = root_path
    task_conf.path = task_toml_path.relative_to(root_path)

    # Scan only after both configs were built successfully.
    check_unnecessary_fields(repo.Config, repo_obj, repo_toml_path)
    check_unnecessary_fields(task.Config, task_obj, task_toml_path)
    return repo_conf, task_conf
|