Skip to content

supaernova.steps.pae.pae

[docs] module supaernova.steps.pae.pae

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
# Copyright 2025 Patrick Armstrong
from typing import TYPE_CHECKING, ClassVar, override

import numpy as np

from supaernova.steps.model import AbstractModelStep
from supaernova.configs.steps.data import DataStepResult

from .model import PAEModelStep

if TYPE_CHECKING:
    from logging import Logger

    from pydantic import PositiveFloat

    from supaernova.steps.data import DataStep
    from supaernova.configs.paths import PathConfig
    from supaernova.configs.globals import GlobalConfig
    from supaernova.configs.steps.pae import PAEStepConfig


class PAEStep[Backend: str](AbstractModelStep[Backend, PAEModelStep[Backend]]):
    # --- Class Variables ---
    id: ClassVar[str] = "pae"

    def __init__(self, config: "PAEStepConfig[Backend]") -> None:
        # --- Superclass Variables ---
        self.options: PAEStepConfig[Backend]
        self.config: GlobalConfig
        self.paths: PathConfig
        self.log: Logger
        self.force: bool
        self.verbose: bool
        super().__init__(config)

        # --- Previous Step Variables ---
        self.data: DataStep

        # --- Setup Variables ---
        self.train_data: list[DataStepResult]
        self.test_data: list[DataStepResult]
        self.val_data: list[DataStepResult]
        self.all_data: list[DataStepResult]
        self.n_models: int
        self.n_kfolds: int

    @override
    def _setup(self, *, data: "DataStep") -> None:
        super()._setup()

        # --- Previous Step Variables ---
        self.data = data

        # --- Models ---
        self.n_kfolds = self.data.n_kfolds
        self.log.debug(
            f"Training {self.n_models} models across {self.n_kfolds} kfolds."
        )
        if self.n_models > self.n_kfolds:
            self.log.warning(
                f"Data has {self.n_kfolds} kfolds, but {self.n_models} models were requested, some models will share the same training, testing, and validation data."
            )

        # --- Data ---
        train_data = self.data.train_data
        test_data = self.data.test_data
        all_data = self.data.data
        val_data = test_data

        if self.options.kfolds is None:
            self.kfolds = list(range(self.n_kfolds))
            # `(list * ((desired_length // actual_length) + 1))[:desired_length]`
            # Repeat `list` `(desired_length // actual_length) + 1` times, then take the first `desired_length` items
            self.train_data = (train_data * ((self.n_models // self.n_kfolds) + 1))[
                : self.n_models
            ]
            self.test_data = (test_data * ((self.n_models // self.n_kfolds) + 1))[
                : self.n_models
            ]
            self.val_data = (val_data * ((self.n_models // self.n_kfolds) + 1))[
                : self.n_models
            ]
        else:
            self.kfolds = self.options.kfolds
            self.train_data = train_data
            self.test_data = test_data
            self.val_data = val_data

        self.all_data = all_data

        for i, model in enumerate(self.models):
            model.setup(
                data=self.data,
                train_data=self.train_data[self.kfolds[i]],
                test_data=self.test_data[self.kfolds[i]],
                val_data=self.val_data[self.kfolds[i]],
                all_data=self.all_data,
            )


# Register PAEStep via the ``register_step`` hook inherited from its base class
# (defined outside this file). NOTE(review): presumably this makes the step
# discoverable under its ``id`` ("pae") — confirm against AbstractModelStep.
PAEStep.register_step()