Skip to content

supaernova.steps.model

[docs] module supaernova.steps.model

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from typing import TYPE_CHECKING, get_args, override

from .steps import SNPAEStep

if TYPE_CHECKING:
    from typing import Any
    from logging import Logger

    from supaernova.configs.paths import PathConfig
    from supaernova.configs.globals import GlobalConfig
    from supaernova.configs.steps.model import AbstractModelStepConfig
    from supaernova.configs.steps.steps import AbstractStepResult
    from supaernova.configs.steps.backends import AbstractModelConfig

    from .backends import AbstractModel


class AbstractModelStep[Backend: str, Model: AbstractModel[Backend]](SNPAEStep):
    def __init__(
        self, config: "AbstractModelStepConfig[Backend, AbstractModelConfig]"
    ) -> None:
        # --- Superclass Variables ---
        self.options: AbstractModelStepConfig[Backend, AbstractModelConfig]
        self.config: GlobalConfig
        self.paths: PathConfig
        self.log: Logger
        self.force: bool
        self.verbose: bool
        super().__init__(config)

        self.models: list[Model]
        self.n_models: int
        self.results: list[AbstractStepResult]

    @override
    def _setup(self, *args: "Any", **kwargs: "Any") -> None:
        model_step: type[Model] = get_args(self.__orig_bases__[0])[1]
        self.models = [model_step(model) for model in self.options.models or []]
        self.n_models = len(self.models)

    @override
    def _completed(self) -> bool:
        return all(model.completed() for model in self.models)

    @override
    def _load(self) -> None:
        for model in self.models:
            model.load()

    @override
    def _run(self) -> None:
        for model in self.models:
            model.run()

    @override
    def _result(self) -> None:
        for model in self.models:
            model.result()
        self.results = [model.results for model in self.models]

    @override
    def _analyse(self) -> None:
        for model in self.models:
            model._analyse()