Skip to content

supaernova.configs.steps.data

[docs] module supaernova.configs.steps.data

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# Copyright 2025 Patrick Armstrong


from typing import Self, Literal, ClassVar, Annotated
from pathlib import Path
from collections.abc import Callable

import numpy as np
from numpy import typing as npt
from astropy import cosmology as cosmo
from pydantic import (
    Field,
    PositiveInt,
    field_validator,
    model_validator,
)

from supaernova.analysis.spectra import SpectraPlot

from .steps import StepConfig, AbstractStepResult, AbstractStepAnalysis


class DataStepResult(AbstractStepResult):
    """Array-valued result model for the data step.

    Every field is annotated as a NumPy array whose dtype is given in the
    annotation (``int32``, ``float32`` or ``str_``).

    NOTE(review): array shapes/lengths are not declared here. The names
    suggest some fields are per-supernova (e.g. ``sn_name``, ``redshift``)
    and others per-spectrum or per-wavelength-bin (e.g. ``spectra_id``,
    ``wavelength``) — confirm against the code that builds this result.
    """

    ind: "npt.NDArray[np.int32]"
    nspectra: "npt.NDArray[np.int32]"
    sn_name: "npt.NDArray[np.str_]"
    dphase: "npt.NDArray[np.float32]"
    redshift: "npt.NDArray[np.float32]"
    # NOTE(review): x0 / x1 / c look like SALT light-curve parameters (the
    # matching config exposes a ``salt_model``) — confirm before relying on it.
    x0: "npt.NDArray[np.float32]"
    x1: "npt.NDArray[np.float32]"
    c: "npt.NDArray[np.float32]"
    MB: "npt.NDArray[np.float32]"
    hubble_residual: "npt.NDArray[np.float32]"
    luminosity_distance: "npt.NDArray[np.float32]"
    spectra_id: "npt.NDArray[np.str_]"
    phase: "npt.NDArray[np.float32]"
    # Lower/upper wavelength mask bounds, stored as float32 arrays.
    wl_mask_min: "npt.NDArray[np.float32]"
    wl_mask_max: "npt.NDArray[np.float32]"
    amplitude: "npt.NDArray[np.float32]"
    sigma: "npt.NDArray[np.float32]"
    salt_flux: "npt.NDArray[np.float32]"
    wavelength: "npt.NDArray[np.float32]"
    # Integer-typed mask (int32, not bool) — presumably 0/1 flags; confirm.
    mask: "npt.NDArray[np.int32]"
    time: "npt.NDArray[np.float32]"


class DataStepAnalysis(AbstractStepAnalysis):
    """Analysis/plotting options for the data step.

    Each option accepts a single ``SpectraPlot``, a list of ``SpectraPlot``
    objects, or ``None`` (the default).
    """

    # Spectra plot configuration(s); ``None`` presumably means "no plot" — confirm.
    plot_spectra: SpectraPlot | list[SpectraPlot] | None = None
    # Summary plot configuration(s); same accepted types as ``plot_spectra``.
    plot_summary: SpectraPlot | list[SpectraPlot] | None = None


class DataStepConfig(StepConfig):
    """Configuration for the ``data`` step.

    Path handling (see :meth:`validate_paths`):

    * ``data_dir`` is resolved relative to ``paths.base`` and must exist.
    * ``meta``, ``idr`` and ``mask`` are resolved relative to the resolved
      ``data_dir``. They must exist and have the suffixes ``.csv``, ``.txt``
      and ``.txt`` respectively.
    * ``colourlaw`` may be ``None``. Otherwise it is resolved relative to
      ``data_dir`` and must exist.

    Other checks:

    * ``cosmological_model`` must be a name listed in
      ``astropy.cosmology.realizations.available``.
    * ``salt_model`` must contain ``"salt2"`` or ``"salt3"``. If it resolves
      (relative to ``paths.base``) to an existing path, it is replaced by
      that resolved ``Path``; otherwise the original value is kept.
    * ``max_phase`` must be strictly greater than ``min_phase``.

    All failures are reported through ``_raise`` (defined on the base class).
    """

    # --- Class Variables ---
    # Step identifier. NOTE(review): presumably the key used by
    # ``register_step`` — confirm in ``.steps``.
    id: ClassVar[str] = "data"

    # --- Required ---
    data_dir: Path
    meta: Path
    idr: Path
    mask: Path
    # Must be supplied explicitly, but may be ``None``.
    colourlaw: Path | None

    # --- Optional ---
    analysis: DataStepAnalysis = DataStepAnalysis.model_validate({})
    cosmological_model: str = "WMAP7"
    salt_model: str | Path = "salt2"
    min_phase: float = -10
    max_phase: float = 40
    seed: PositiveInt = 12345
    train_frac: Annotated[float, Field(ge=0, le=1)] = 0.75

    @model_validator(mode="after")
    def validate_paths(self) -> Self:
        """Resolve every path field in place and check it exists (and its suffix).

        Returns:
            ``self`` with ``data_dir``, ``meta``, ``idr``, ``mask`` and
            (when set) ``colourlaw`` replaced by their resolved paths.
        """
        self.data_dir = self.paths.resolve_path(
            self.data_dir, relative_path=self.paths.base
        )
        if not self.data_dir.exists():
            err = f"`data_dir` resolved to {self.data_dir}, which does not exist."
            self._raise(err)

        # These files live under ``data_dir``, so resolve them relative to it.
        expected_suffixes = {"meta": ".csv", "idr": ".txt", "mask": ".txt"}
        for name, ext in expected_suffixes.items():
            resolved: Path = self.paths.resolve_path(
                getattr(self, name), relative_path=self.data_dir
            )
            setattr(self, name, resolved)

            if not resolved.exists():
                err = f"`{name}` resolved to {resolved}, which does not exist."
                self._raise(err)

            if resolved.suffix != ext:
                err = f"`{name}` resolved to {resolved}, which is not a {ext} file."
                self._raise(err)

        if self.colourlaw is not None:
            self.colourlaw = self.paths.resolve_path(
                self.colourlaw, relative_path=self.data_dir
            )
            if not self.colourlaw.exists():
                err = f"`colourlaw` resolved to {self.colourlaw}, which does not exist."
                self._raise(err)
        return self

    @field_validator("cosmological_model", mode="after")
    @classmethod
    def validate_cosmological_model(cls, value: str) -> str:
        """Require ``value`` to name one of astropy's built-in cosmology realizations."""
        if value not in cosmo.realizations.available:
            err = f"`cosmological_model` is {value} but must be one of {cosmo.realizations.available}"
            cls._raise(err)
        return value

    @field_validator("salt_model", mode="after")
    @classmethod
    def validate_salt_model(cls, value: str | Path) -> str | Path:
        """Require the SALT model name/path to mention ``salt2`` or ``salt3``.

        Args:
            value: Model name or path. Pydantic keeps a ``Path`` input as a
                ``Path`` for the ``str | Path`` union.

        Returns:
            ``value`` unchanged.
        """
        # ``x in Path(...)`` raises TypeError (Path is not a container), and
        # pydantic does not turn TypeError into a ValidationError, so test the
        # string form instead.
        text = str(value)
        if ("salt2" not in text) and ("salt3" not in text):
            err = f'`salt_model` is {value} but does not appear to be a salt2 or salt3 model, as it does not contain the string `"salt2"` or `"salt3"`'
            cls._raise(err)
        return value

    @model_validator(mode="after")
    def validate_salt_model_path(self) -> Self:
        """Replace ``salt_model`` with its resolved path when that path exists.

        If the resolved path does not exist, ``salt_model`` is left as given.
        NOTE(review): presumably it is then treated as a named model rather
        than a file — confirm where it is consumed.
        """
        salt_path = self.paths.resolve_path(
            Path(self.salt_model), relative_path=self.paths.base
        )
        if salt_path.exists():
            self.salt_model = salt_path
        return self

    @model_validator(mode="after")
    def validate_max_phase(self) -> Self:
        """Require ``max_phase`` to be strictly greater than ``min_phase``."""
        if self.max_phase <= self.min_phase:
            err = f"`max_phase`: {self.max_phase} is not strictly greater than `min_phase`: {self.min_phase}"
            self._raise(err)
        return self


# Register this config class with the step machinery from ``.steps``. This is a
# module-level call, so it runs when the module is imported.
# NOTE(review): presumably registers under ``DataStepConfig.id`` ("data") — confirm.
DataStepConfig.register_step()