diff --git a/joj3_config_generator/processers/task.py b/joj3_config_generator/processers/task.py index e21b208..52bcc74 100644 --- a/joj3_config_generator/processers/task.py +++ b/joj3_config_generator/processers/task.py @@ -152,55 +152,42 @@ def fix_file(file_parser_config: task.ParserFile, file_parser: result.Parser) -> def fix_diff( task_stage: task.Stage, diff_parser_config: result.Parser, - diff_executor: result.Executor, + diff_executor_config: result.Executor, base_dir: Path, ) -> None: - skip = task_stage.skip - cases = task_stage.cases - finalized_cases = [case for case in cases if case not in skip] - + valid_cases = ( + (case, task_stage.cases[case]) + for case in task_stage.cases + if case not in task_stage.skip and case in task_stage.cases + ) stage_cases = [] parser_cases = [] - - for case in finalized_cases: - case_stage = task_stage.cases.get(case) if task_stage.cases else None - if not case_stage: - continue - - cpu_limit = case_stage.limit.cpu - clock_limit = 2 * case_stage.limit.cpu - memory_limit = case_stage.limit.mem - command = case_stage.command - stdin = case_stage.in_ if case_stage.in_ != "" else f"{case}.in" - stdout = case_stage.out_ if case_stage.out_ != "" else f"{case}.out" - + for case, case_stage in valid_cases: stage_cases.append( result.OptionalCmd( stdin=result.LocalFile( - src=str(base_dir / stdin), + src=str(base_dir / (case_stage.in_ or f"{case}.in")) ), - args=shlex.split(command) if command else None, - cpu_limit=cpu_limit, - clock_limit=clock_limit, - memory_limit=memory_limit, + args=shlex.split(case_stage.command) if case_stage.command else None, + cpu_limit=case_stage.limit.cpu, + clock_limit=2 * case_stage.limit.cpu, + memory_limit=case_stage.limit.mem, proc_limit=50, ) ) - diff_output = case_stage.diff.output parser_cases.append( result.DiffCasesConfig( outputs=[ result.DiffOutputConfig( - score=diff_output.score, + score=case_stage.diff.output.score, file_name="stdout", - answer_path=str(base_dir / stdout), - force_quit_on_diff=diff_output.force_quit, - always_hide=diff_output.hide, - compare_space=not diff_output.ignore_spaces, + answer_path=str(base_dir / (case_stage.out_ or f"{case}.out")), + force_quit_on_diff=case_stage.diff.output.force_quit, + always_hide=case_stage.diff.output.hide, + compare_space=not case_stage.diff.output.ignore_spaces, ) ] ) ) - - diff_executor.with_.cases = stage_cases + diff_executor_config.with_.cases = stage_cases diff_parser_config.with_ = result.DiffConfig(name="diff", cases=parser_cases)