Skip to content

supaernova.configs.input

[docs] module supaernova.configs.input

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright 2025 Patrick Armstrong

from typing import Self

from pydantic import FilePath, computed_field, model_validator

from supaernova.steps import SNPAEStep
from supaernova.steps.pae import PAEStep
from supaernova.steps.data import DataStep
from supaernova.steps.nflow import NFlowStep
from supaernova.steps.posterior import PosteriorStep

from .steps import StepConfig
from .configs import SNPAEConfig
from .steps.pae import PAEStepConfig
from .steps.data import DataStepConfig
from .steps.nflow import NFlowStepConfig
from .steps.backends import Backend
from .steps.posterior import PosteriorStepConfig


class InputConfig(SNPAEConfig):
    # Top-level input configuration: holds one optional config section per
    # pipeline step (data, pae, nflow, posterior), checks that every defined
    # step has its prerequisites defined, and drives the steps in that order.
    #
    # NOTE: the class and its computed fields are documented with comments
    # rather than docstrings because pydantic copies those docstrings into the
    # generated JSON schema as descriptions.

    # Per-step configuration sections; a section left as None is not defined.
    data: DataStepConfig | None = None
    pae: PAEStepConfig[Backend] | None = None
    nflow: NFlowStepConfig[Backend] | None = None
    posterior: PosteriorStepConfig[Backend] | None = None

    # Slots for step objects. `run` stores each step under `<step.id>_step`
    # once it has been driven (presumably matching these names -- confirm
    # against the step ids).
    data_step: DataStep | None = None
    pae_step: PAEStep[Backend] | None = None
    nflow_step: NFlowStep[Backend] | None = None
    posterior_step: PosteriorStep[Backend] | None = None

    # Configs of the defined steps, always in the fixed order
    # data -> pae -> nflow -> posterior; undefined (None) sections are skipped.
    @computed_field
    @property
    def step_configs(self) -> list[StepConfig]:
        ordered_sections = (self.data, self.pae, self.nflow, self.posterior)
        return [section for section in ordered_sections if section is not None]

    # One step object per defined config, built by looking the config's id up
    # in the `SNPAEStep.steps` registry. Fresh objects are built on every
    # access.
    @computed_field
    @property
    def steps(self) -> list[SNPAEStep]:
        built: list[SNPAEStep] = []
        for config in self.step_configs:
            step_factory = SNPAEStep.steps[config.id]
            built.append(step_factory(config))
        return built

    @model_validator(mode="after")
    def validate_steps(self) -> Self:
        """Check that at least one step is defined and prerequisites exist.

        Every name listed in a defined step's `required_steps` must itself be
        a defined (non-None) section on this config. Problems are reported
        through `self._raise`.

        Returns:
            This instance, as expected from an "after" model validator.
        """
        defined = self.step_configs
        if not defined:
            err = f"No steps have been defined! Please specify at least one of {list(SNPAEStep.steps.keys())}"
            self._raise(err)

        for step_config in defined:
            for required_step in step_config.required_steps:
                if getattr(self, required_step) is not None:
                    continue
                err = f"{step_config.id} requires that {required_step} is run first, but {required_step} has not been defined!"
                self._raise(err)
        return self

    def require(self, step_name: str) -> SNPAEStep:
        """Return the step object stored under `<step_name>_step`.

        Args:
            step_name: Name of the step, without the `_step` suffix.

        Returns:
            The stored step object. If none is stored yet, the problem is
            reported through `self._raise` first.
        """
        stored = getattr(self, step_name + "_step")
        if stored is not None:
            return stored
        err = f"{step_name} has not yet run"
        self._raise(err)
        return stored

    def run(self) -> None:
        """Drive every defined step through its four phases, in order.

        Each phase (setup, run, result, analyse) receives the already stored
        step objects named in the step's `options.required_steps` as keyword
        arguments. After its last phase the step is stored on this config as
        `<step.id>_step` so later steps can require it.
        """
        for step in self.steps:
            dependencies = {}
            for required_step in step.options.required_steps:
                dependencies[required_step] = self.require(required_step)
            # Look each phase up by name just before calling it, so the lookup
            # happens at the same point as a direct `step.<phase>(...)` call.
            for phase in ("setup", "run", "result", "analyse"):
                getattr(step, phase)(**dependencies)
            setattr(self, step.id + "_step", step)