From 3e252e1ff0646eeff71b855b66bf8cf1920bffcb Mon Sep 17 00:00:00 2001
From: Boming Zhang <bomingzh@sjtu.edu.cn>
Date: Sun, 2 Mar 2025 14:26:25 -0500
Subject: [PATCH] chore: better defaults and code style

---
 joj3_config_generator/models/repo.py     |   5 +-
 joj3_config_generator/processers/joj1.py |  18 +---
 joj3_config_generator/processers/repo.py | 119 ++++++++---------------
 3 files changed, 49 insertions(+), 93 deletions(-)

diff --git a/joj3_config_generator/models/repo.py b/joj3_config_generator/models/repo.py
index da0c6bb..41b7d94 100644
--- a/joj3_config_generator/models/repo.py
+++ b/joj3_config_generator/models/repo.py
@@ -1,5 +1,6 @@
+import socket
 from pathlib import Path
-from typing import List, Optional
+from typing import List
 
 from pydantic import BaseModel, Field
 
@@ -25,4 +26,4 @@ class Config(BaseModel):
     groups: Groups = Groups()
     root: Path = Path(".")
     path: Path = Path("repo.toml")
-    grading_repo_name: Optional[str] = None
+    grading_repo_name: str = f"{socket.gethostname().split('-')[0]}-joj"
diff --git a/joj3_config_generator/processers/joj1.py b/joj3_config_generator/processers/joj1.py
index 5eeffc3..ce34599 100644
--- a/joj3_config_generator/processers/joj1.py
+++ b/joj3_config_generator/processers/joj1.py
@@ -1,6 +1,5 @@
 from joj3_config_generator.models import joj1, task
 from joj3_config_generator.models.common import Memory, Time
-from joj3_config_generator.models.const import DEFAULT_CPU_LIMIT, DEFAULT_MEMORY_LIMIT
 
 
 def get_joj1_run_stage(joj1_config: joj1.Config) -> task.Stage:
@@ -10,10 +9,7 @@ def get_joj1_run_stage(joj1_config: joj1.Config) -> task.Stage:
             task.Stage(
                 score=case.score,
                 command=case.execute_args if case.execute_args else "",
-                limit=task.Limit(
-                    cpu=Time(case.time) if case.time else DEFAULT_CPU_LIMIT,
-                    mem=(Memory(case.memory) if case.memory else DEFAULT_MEMORY_LIMIT),
-                ),
+                limit=task.Limit(cpu=Time(case.time), mem=Memory(case.memory)),
             )
         )
     for i, case in enumerate(joj1_config.cases):
@@ -24,16 +20,8 @@ def get_joj1_run_stage(joj1_config: joj1.Config) -> task.Stage:
         parsers=["diff", "result-status"],
         score=100,
         limit=task.Limit(
-            cpu=(
-                Time(joj1_config.cases[0].time)
-                if joj1_config.cases[0].time is not None
-                else DEFAULT_CPU_LIMIT
-            ),
-            mem=(
-                Memory(joj1_config.cases[0].memory)
-                if joj1_config.cases[0].memory is not None
-                else DEFAULT_MEMORY_LIMIT
-            ),
+            cpu=Time(joj1_config.cases[0].time),
+            mem=Memory(joj1_config.cases[0].memory),
         ),
         cases={f"case{i}": cases_conf[i] for i, _ in enumerate(cases_conf)},
     )  # TODO: no strong pattern match here, use dict instead
diff --git a/joj3_config_generator/processers/repo.py b/joj3_config_generator/processers/repo.py
index f8396af..5b2270d 100644
--- a/joj3_config_generator/processers/repo.py
+++ b/joj3_config_generator/processers/repo.py
@@ -1,28 +1,20 @@
 import hashlib
-import shlex
-import socket
 from pathlib import Path
+from typing import List
 
 from joj3_config_generator.models import repo, result
 
 
-def get_grading_repo_name(repo_conf: repo.Config) -> str:
-    host_name = "ece280"
-    host_name = socket.gethostname()
-    grading_repo_name = (
-        repo_conf.grading_repo_name
-        if repo_conf.grading_repo_name is not None
-        else f"{host_name.split('-')[0]}-joj"
-    )
-    return grading_repo_name
-
-
 def get_teapot_stage(repo_conf: repo.Config) -> result.StageDetail:
-    args_ = ""
-    args_ = (
-        args_
-        + f"/usr/local/bin/joint-teapot joj3-all-env /home/tt/.config/teapot/teapot.env --grading-repo-name {get_grading_repo_name(repo_conf)} --max-total-score {repo_conf.max_total_score}"
-    )
+    args = [
+        "/usr/local/bin/joint-teapot",
+        "joj3-all-env",
+        "/home/tt/.config/teapot/teapot.env",
+        "--grading-repo-name",
+        repo_conf.grading_repo_name,
+        "--max-total-score",
+        str(repo_conf.max_total_score),
+    ]
 
     stage_conf = result.StageDetail(
         name="teapot",
@@ -30,7 +22,7 @@ def get_teapot_stage(repo_conf: repo.Config) -> result.StageDetail:
             name="local",
             with_=result.ExecutorWith(
                 default=result.Cmd(
-                    args=shlex.split(args_),
+                    args=args,
                     env=["LOG_FILE_PATH=/home/tt/.cache/joint-teapot-debug.log"],
                 ),
                 cases=[],
@@ -41,49 +33,37 @@ def get_teapot_stage(repo_conf: repo.Config) -> result.StageDetail:
     return stage_conf
 
 
-def get_healthcheck_args(repo_conf: repo.Config) -> str:
-    repoSize = repo_conf.max_size
-    immutable = repo_conf.files.immutable
-    repo_size = f"-repoSize={str(repoSize)} "
-    required_files = repo_conf.files.required
-
-    for i, meta in enumerate(required_files):
-        required_files[i] = f"-meta={meta} "
-
-    immutable_files = "-checkFileNameList="
-    for i, name in enumerate(immutable):
-        if i == len(immutable) - 1:
-            immutable_files = immutable_files + name + " "
-        else:
-            immutable_files = immutable_files + name + ","
-    args = "/usr/local/bin/repo-health-checker -root=. "
-    args = args + repo_size
-    for meta in required_files:
-        args = args + meta
-
-    args = args + get_hash(immutable, repo_conf)
-
-    args = args + immutable_files
-
-    return args
+def get_healthcheck_args(repo_conf: repo.Config) -> List[str]:
+    return [
+        "/usr/local/bin/repo-health-checker",
+        "-root=.",
+        f"-repoSize={str(repo_conf.max_size)}",
+        *[f"-meta={meta}" for meta in repo_conf.files.required],
+        get_hash(repo_conf),
+        f"-checkFileNameList={','.join(repo_conf.files.immutable)}",
+    ]
 
 
-def get_debug_args(repo_conf: repo.Config) -> str:
-    args = ""
-    args = (
-        args
-        + f"/usr/local/bin/joint-teapot joj3-check-env /home/tt/.config/teapot/teapot.env --grading-repo-name {get_grading_repo_name(repo_conf)} --group-config "
-    )
+def get_debug_args(repo_conf: repo.Config) -> List[str]:
     group_config = ""
     for i, name in enumerate(repo_conf.groups.name):
         group_config = (
             group_config
-            + f"{name}={repo_conf.groups.max_count[i]}:{repo_conf.groups.time_period_hour[i]},"
+            + f"{name}="
+            + f"{repo_conf.groups.max_count[i]}:"
+            + f"{repo_conf.groups.time_period_hour[i]},"
         )
     # default value hardcoded
     group_config = group_config + "=100:24"
-    args = args + group_config
-    return args
+    return [
+        "/usr/local/bin/joint-teapot",
+        "joj3-check-env",
+        "/home/tt/.config/teapot/teapot.env",
+        "--grading-repo-name",
+        repo_conf.grading_repo_name,
+        "--group-config",
+        group_config,
+    ]
 
 
 def get_healthcheck_config(repo_conf: repo.Config) -> result.StageDetail:
@@ -96,10 +76,10 @@ def get_healthcheck_config(repo_conf: repo.Config) -> result.StageDetail:
                 default=result.Cmd(),
                 cases=[
                     result.OptionalCmd(
-                        args=shlex.split(get_healthcheck_args(repo_conf)),
+                        args=get_healthcheck_args(repo_conf),
                     ),
                     result.OptionalCmd(
-                        args=shlex.split(get_debug_args(repo_conf)),
+                        args=get_debug_args(repo_conf),
                         env=["LOG_FILE_PATH=/home/tt/.cache/joint-teapot-debug.log"],
                     ),
                 ],
@@ -121,24 +101,11 @@ def calc_sha256sum(file_path: Path) -> str:
     return sha256_hash.hexdigest()
 
 
-def get_hash(
-    immutable_files: list[str], repo_conf: repo.Config
-) -> str:  # input should be a list
-    repo_path = (repo_conf.root / repo_conf.path).parent
-    file_path = Path(f"{repo_path}/immutable_files")
-    immutable_hash = []
-    immutable_files_ = []
-    for i, file in enumerate(immutable_files):
-        immutable_files_.append(file_path.joinpath(file.rsplit("/", 1)[-1]))
-
-    for i, file_ in enumerate(immutable_files_):
-        immutable_hash.append(calc_sha256sum(file_))
-
-    hash_check = "-checkFileSumList="
-
-    for i, file in enumerate(immutable_hash):
-        if i == len(immutable_hash) - 1:
-            hash_check = hash_check + file + " "
-        else:
-            hash_check = hash_check + file + ","
-    return hash_check
+def get_hash(repo_conf: repo.Config) -> str:  # input should be a list
+    base_dir = (repo_conf.root / repo_conf.path).parent
+    immutable_dir = base_dir / "immutable_files"
+    immutable_files = [
+        immutable_dir / Path(file).name for file in repo_conf.files.immutable
+    ]
+    immutable_hash = [calc_sha256sum(file) for file in immutable_files]
+    return f"-checkFileSumList={','.join(immutable_hash)}"