Skip to content

supaernova.configs.steps.model

[docs] module supaernova.configs.steps.model

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from typing import Any, TypeVar, ClassVar, get_args
from collections.abc import Callable

from pydantic import (
    Field,
    BaseModel,
    ConfigDict,
    model_validator,
)

from .steps import StepConfig
from .backends import BACKENDS, BACKENDS_STR, AbstractModelConfig


class AbstractModelStepConfig[Backend: str, ModelConfig: AbstractModelConfig](
    StepConfig
):
    model_backend: ClassVar[dict[str, Callable[[], type[ModelConfig]]]]

    # --- Models ---
    model: ModelConfig
    models: list[ModelConfig] | None = Field(None, validation_alias="variant")

    @model_validator(mode="before")
    @classmethod
    def prep_model_config(cls, data: Any) -> Any:
        if isinstance(data, dict):
            if "model" not in data:
                err = f"No Base Model has been defined. Please define one in [{cls.id}.model]"
                raise ValueError(err)

            if isinstance(data["model"], AbstractModelConfig):
                data["variant"] = data["models"]
                data.pop("models", None)
            else:
                default_model_config = {
                    "paths": data.get("paths"),
                    "config": data.get("config"),
                    "log": data.get("log"),
                }
                base_model_config = {**default_model_config, **data.get("model", {})}
                data["model"] = base_model_config

                model_configs = [
                    data["model"],
                    *[
                        {**base_model_config, **model_config}
                        for model_config in data.get("variant") or []
                    ],
                ]
                data.pop("variant", None)
                data["variant"] = []
                for i, model_config in enumerate(model_configs):
                    backend = model_config.get("backend")
                    if backend is None:
                        err = f"{'Base' if i == 0 else f'Variant {i}'} Model is missing a backend key. Please choose from {BACKENDS_STR}"
                        raise ValueError(err)
                    model_config_cls = None
                    for backend_name in BACKENDS:
                        if backend in get_args(BACKENDS[backend_name]):
                            model_config_cls = cls.model_backend[backend_name]
                    if model_config_cls is None:
                        err = f"Unknown backend: {backend}. Please choose from {BACKENDS_STR}"
                        raise ValueError(err)
                    model_variant = model_config_cls().from_config(model_config)
                    data["variant"].append(model_variant)
        return data