import json
import types
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Type, Union, cast, get_args, get_origin

import inquirer
import tomli
import yaml
from pydantic import BaseModel

from joj3_config_generator.models import answer, joj1, repo, task
from joj3_config_generator.models.common import Memory, Time
from joj3_config_generator.utils.logger import logger


def load_joj3_task_toml_answers() -> answer.Answers:
    """Interactively collect the answers needed to generate a JOJ3 task.

    Asks for the task name and the language. The user then either loads the
    content of a template (questions from ``language.get_template_questions()``)
    or picks the stages and answers the language-specific attribute questions.
    Those answers are handed to ``language.set_stages`` /
    ``language.set_attribute``.

    Returns:
        The collected ``answer.Answers``.

    Raises:
        KeyboardInterrupt: if the user cancels any prompt. The ``inquirer.prompt``
            calls pass ``raise_keyboard_interrupt=True`` so they behave like the
            other shortcut prompts. Without it they would swallow the interrupt
            and return ``None``, crashing later with a ``TypeError``.
    """
    name = inquirer.text("What's the task name?", default="hw0")
    language = inquirer.list_input(
        "What's the language?", choices=[(cls.name, cls) for cls in answer.LANGUAGES]
    )
    # ``list_input`` returns the value half of the chosen (label, value) pair.
    language = cast(Type[answer.LanguageInterface], language)
    if inquirer.confirm("Load content from templates?", default=True):
        answers = inquirer.prompt(
            language.get_template_questions(), raise_keyboard_interrupt=True
        )
        template_file_content: str = answers["template_file_content"]
        return answer.Answers(
            name=name, language=language, template_file_content=template_file_content
        )
    stages = inquirer.checkbox(
        "What's the stages?",
        choices=[member.value for member in language.Stage],
        default=[member.value for member in language.Stage],
    )
    language.set_stages(stages)
    attribute = inquirer.prompt(
        language.get_attribute_questions(), raise_keyboard_interrupt=True
    )
    language.set_attribute(attribute)
    return answer.Answers(name=name, language=language)


def load_joj1_yaml(yaml_path: Path) -> joj1.Config:
    """Parse a JOJ1 YAML config file into a ``joj1.Config``.

    The file is read as UTF-8, independent of the platform's locale encoding.
    ``yaml.safe_load`` is used so the document cannot construct arbitrary
    Python objects.

    Args:
        yaml_path: path of the JOJ1 YAML file.

    Returns:
        The config built from the top-level mapping of the document.
    """
    joj1_obj = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
    if joj1_obj is None:
        # An empty document parses to ``None``. Treat it as an empty mapping so
        # the model reports missing fields instead of ``**None`` raising TypeError.
        joj1_obj = {}
    return joj1.Config(**joj1_obj)


# Sentinel returned by ``getattr`` for model fields that have no default.
# Required fields are left unset by ``BaseModel.model_construct()``.
_NO_DEFAULT = object()

# ``typing.get_origin`` reports ``Union`` for ``Optional[X]``. On Python
# 3.10+ it reports ``types.UnionType`` for ``X | None``.
_UNION_ORIGINS = (Union, getattr(types, "UnionType", Union))


def _format_toml_key(key: Any) -> str:
    """Render ``key`` as a TOML bare key if it only uses ``A-Za-z0-9_-``, else quoted."""
    text = str(key)
    if text and all(ch.isascii() and (ch.isalnum() or ch in "-_") for ch in text):
        return text
    return json.dumps(text, ensure_ascii=False)


def _format_value_for_toml_warning(value: Any) -> str:
    """Render a value parsed from TOML the way it would be written in TOML.

    Used only to quote the offending line in warnings.

    - Strings and paths become basic strings. ``json.dumps`` escapes quotes,
      backslashes and control characters with escapes that TOML also accepts.
    - Booleans are lower-cased.
    - Lists become arrays and dicts become inline tables.
    """
    if isinstance(value, (str, Path)):
        return json.dumps(str(value), ensure_ascii=False)
    if isinstance(value, bool):  # checked before int: bool subclasses int
        return str(value).lower()
    if isinstance(value, (int, float)):
        return str(value)
    if isinstance(value, list):
        return f"[{', '.join(_format_value_for_toml_warning(item) for item in value)}]"
    if isinstance(value, dict):
        if not value:
            return "{}"
        entries = ", ".join(
            f"{_format_toml_key(key)} = {_format_value_for_toml_warning(item)}"
            for key, item in value.items()
        )
        return f"{{ {entries} }}"
    if value is None:
        return "None"
    return repr(value)


def _unwrap_optional(annotation: Any) -> Any:
    """Return ``X`` for ``Optional[X]`` / ``X | None``; otherwise ``annotation``."""
    if get_origin(annotation) in _UNION_ORIGINS:
        members = [arg for arg in get_args(annotation) if arg is not type(None)]
        if len(members) == 1:
            return members[0]
    return annotation


def _as_model_type(annotation: Any) -> Optional[Type[BaseModel]]:
    """Return ``annotation`` if it is a pydantic model class, else ``None``."""
    # The ``get_origin`` guard matters on Python 3.9/3.10. There, ``list[int]``
    # can pass ``isinstance(..., type)`` and then make ``issubclass`` raise.
    if (
        get_origin(annotation) is None
        and isinstance(annotation, type)
        and issubclass(annotation, BaseModel)
    ):
        return annotation
    return None


def _list_item_model(annotation: Any) -> Optional[Type[BaseModel]]:
    """Return ``M`` when ``annotation`` is ``List[M]`` for a pydantic model ``M``."""
    if get_origin(annotation) is not list:
        return None
    args = get_args(annotation)
    return _as_model_type(args[0]) if len(args) == 1 else None


def _matches_default(toml_value: Any, default_value: Any) -> bool:
    """Return whether a raw TOML value is equivalent to a field's default."""
    # Paths are written as strings in TOML.
    if (
        isinstance(toml_value, str)
        and isinstance(default_value, Path)
        and Path(toml_value) == default_value
    ):
        return True
    # Time / Memory defaults: compare after converting the raw TOML value.
    if isinstance(default_value, Time) and Time(toml_value) == default_value:
        return True
    if isinstance(default_value, Memory) and Memory(toml_value) == default_value:
        return True
    # Lists, scalars and tables. An explicitly empty table is never reported,
    # even when the default is also empty.
    return toml_value == default_value and toml_value != {}


def _check_unnecessary_fields(
    pydantic_model_type: Type[BaseModel],
    input_dict: Dict[str, Any],
    file_path: Path,
    current_path: str = "",
) -> None:
    """Warn about entries of ``input_dict`` that only restate model defaults.

    Fields annotated as a model, a list of models, or ``Optional`` of either
    are checked recursively. Paths in warnings use the key as spelled in the
    file, i.e. the alias when the table uses it.

    Args:
        pydantic_model_type: model the raw TOML table was validated against.
        input_dict: the raw table as parsed from TOML.
        file_path: file the table came from, used in the warning text.
        current_path: dotted path of ``input_dict`` inside the file ("" at top).
    """
    defaults = pydantic_model_type.model_construct()
    for field_name, field_info in pydantic_model_type.model_fields.items():
        if field_info.alias is not None and field_info.alias in input_dict:
            toml_key = field_info.alias
        elif field_name in input_dict:
            toml_key = field_name
        else:
            continue  # not written in the file, nothing to report
        toml_value = input_dict[toml_key]
        full_field_path = f"{current_path}.{toml_key}" if current_path else toml_key
        annotation = _unwrap_optional(field_info.annotation)

        item_model = _list_item_model(annotation)
        if item_model is not None:
            # Array of tables: check each table against the item model.
            if isinstance(toml_value, list):
                for index, item in enumerate(toml_value):
                    if isinstance(item, dict):
                        _check_unnecessary_fields(
                            item_model, item, file_path, f"{full_field_path}[{index}]"
                        )
            continue

        nested_model = _as_model_type(annotation)
        if nested_model is not None:
            # Sub-table: its fields are compared against the nested model's defaults.
            if isinstance(toml_value, dict):
                _check_unnecessary_fields(
                    nested_model, toml_value, file_path, full_field_path
                )
            continue

        default_value = getattr(defaults, field_name, _NO_DEFAULT)
        if default_value is _NO_DEFAULT:
            continue  # required field: there is no default to repeat
        if _matches_default(toml_value, default_value):
            logger.warning(
                f"In file {file_path}, unnecessary field "
                f"`{full_field_path} = {_format_value_for_toml_warning(toml_value)}`"
                " can be removed as it matches the default value"
            )


def load_joj3_toml(
    root_path: Path, repo_toml_path: Path, task_toml_path: Path
) -> Tuple[repo.Config, task.Config]:
    """Load the repo and task TOML configs and warn about redundant entries.

    Both files are read as UTF-8, the encoding the TOML spec mandates. Each
    config records ``root_path`` and its own path relative to it. Afterwards,
    every entry that merely repeats a model default is logged as a warning.

    Args:
        root_path: root directory both TOML files live under.
        repo_toml_path: path of the repo-level TOML file.
        task_toml_path: path of the task-level TOML file.

    Returns:
        ``(repo_conf, task_conf)``.

    Raises:
        tomli.TOMLDecodeError: if either file is not valid TOML.
        ValueError: if either path is not under ``root_path``.
        Validation errors from ``repo.Config`` / ``task.Config`` also propagate.
    """
    repo_obj = tomli.loads(repo_toml_path.read_text(encoding="utf-8"))
    task_obj = tomli.loads(task_toml_path.read_text(encoding="utf-8"))
    repo_conf = repo.Config(**repo_obj)
    repo_conf.root = root_path
    repo_conf.path = repo_toml_path.relative_to(root_path)
    task_conf = task.Config(**task_obj)
    task_conf.root = root_path
    task_conf.path = task_toml_path.relative_to(root_path)
    _check_unnecessary_fields(repo.Config, repo_obj, repo_toml_path)
    _check_unnecessary_fields(task.Config, task_obj, task_toml_path)
    return repo_conf, task_conf