1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright 2025 Patrick Armstrong
from typing import Self, Literal, ClassVar, Annotated
from pathlib import Path
from collections.abc import Callable
import numpy as np
from numpy import typing as npt
from astropy import cosmology as cosmo
from pydantic import (
Field,
PositiveInt,
field_validator,
model_validator,
)
from supaernova.analysis.spectra import SpectraPlot
from .steps import StepConfig, AbstractStepResult, AbstractStepAnalysis
class DataStepResult(AbstractStepResult):
    """Result container for the ``"data"`` step; every field is a NumPy array.

    Field annotations are quoted strings. Array shapes, and whether a given
    field is indexed per supernova or per spectrum, are not visible in this
    module -- TODO confirm against the code that constructs this result.
    """

    # Integer index array (meaning of the index not shown here -- confirm).
    ind: "npt.NDArray[np.int32]"
    # Integer counts; presumably number of spectra per object -- confirm.
    nspectra: "npt.NDArray[np.int32]"
    sn_name: "npt.NDArray[np.str_]"
    dphase: "npt.NDArray[np.float32]"
    redshift: "npt.NDArray[np.float32]"
    # presumably SALT light-curve fit parameters (x0, x1, c) -- confirm.
    x0: "npt.NDArray[np.float32]"
    x1: "npt.NDArray[np.float32]"
    c: "npt.NDArray[np.float32]"
    # presumably absolute B-band magnitude -- confirm.
    MB: "npt.NDArray[np.float32]"
    hubble_residual: "npt.NDArray[np.float32]"
    luminosity_distance: "npt.NDArray[np.float32]"
    spectra_id: "npt.NDArray[np.str_]"
    phase: "npt.NDArray[np.float32]"
    # presumably lower/upper wavelength bounds of the usable region -- confirm.
    wl_mask_min: "npt.NDArray[np.float32]"
    wl_mask_max: "npt.NDArray[np.float32]"
    amplitude: "npt.NDArray[np.float32]"
    sigma: "npt.NDArray[np.float32]"
    salt_flux: "npt.NDArray[np.float32]"
    wavelength: "npt.NDArray[np.float32]"
    # Integer-typed mask array (encoding of its values not shown here).
    mask: "npt.NDArray[np.int32]"
    time: "npt.NDArray[np.float32]"
class DataStepAnalysis(AbstractStepAnalysis):
    """Analysis/plotting options for the ``"data"`` step.

    Each option accepts a single ``SpectraPlot``, a list of them, or ``None``
    (the default). ``None`` presumably means the plot is skipped -- confirm
    against the analysis code that consumes these options.
    """

    plot_spectra: SpectraPlot | list[SpectraPlot] | None = None
    plot_summary: SpectraPlot | list[SpectraPlot] | None = None
class DataStepConfig(StepConfig):
    """Configuration for the ``"data"`` step.

    The validators below resolve and check every path:

    * ``data_dir`` is resolved relative to ``paths.base`` and must exist.
    * ``meta`` (``.csv``), ``idr`` (``.txt``) and ``mask`` (``.txt``) are
      resolved relative to the resolved ``data_dir``. Each must exist and
      carry the listed suffix.
    * ``colourlaw``, when not ``None``, is resolved relative to ``data_dir``
      and must exist.

    Validation failures are reported through ``_raise``, a helper defined on a
    base class (not visible here).
    """

    # --- Class Variables ---
    id: ClassVar[str] = "data"

    # --- Required ---
    data_dir: Path
    meta: Path  # must resolve to an existing `.csv` file
    idr: Path  # must resolve to an existing `.txt` file
    mask: Path  # must resolve to an existing `.txt` file
    # No default, so it must be supplied explicitly, but `None` is accepted.
    colourlaw: Path | None

    # --- Optional ---
    analysis: DataStepAnalysis = DataStepAnalysis.model_validate({})
    # Must be one of `astropy.cosmology.realizations.available`.
    cosmological_model: str = "WMAP7"
    # A model name or a path; its text must contain "salt2" or "salt3".
    salt_model: str | Path = "salt2"
    # `max_phase` must be strictly greater than `min_phase`.
    min_phase: float = -10
    max_phase: float = 40
    seed: PositiveInt = 12345
    # Constrained to [0, 1]; presumably the training fraction of a split -- confirm.
    train_frac: Annotated[float, Field(ge=0, le=1)] = 0.75

    @model_validator(mode="after")
    def validate_paths(self) -> Self:
        """Resolve the data paths in place and check they exist with the right suffix.

        ``resolve_path`` is provided by ``self.paths`` (not visible here); its
        handling of absolute paths is assumed, not shown -- confirm.

        Returns:
            ``self``, with ``data_dir``, ``meta``, ``idr``, ``mask`` and
            (when set) ``colourlaw`` replaced by their resolved paths.
        """
        self.data_dir = self.paths.resolve_path(
            self.data_dir, relative_path=self.paths.base
        )
        if not self.data_dir.exists():
            err = f"`data_dir` resolved to {self.data_dir}, which does not exist."
            self._raise(err)

        # Files that live inside `data_dir`, with their required suffix.
        for field, ext in {"meta": ".csv", "idr": ".txt", "mask": ".txt"}.items():
            setattr(
                self,
                field,
                self.paths.resolve_path(
                    getattr(self, field), relative_path=self.data_dir
                ),
            )
            field_path: Path = getattr(self, field)
            if not field_path.exists():
                err = f"`{field}` resolved to {field_path}, which does not exist."
                self._raise(err)
            if field_path.suffix != ext:
                err = f"`{field}` resolved to {field_path}, which is not a {ext} file."
                self._raise(err)

        # `colourlaw` is optional, so it is only resolved/checked when provided.
        if self.colourlaw is not None:
            self.colourlaw = self.paths.resolve_path(
                self.colourlaw, relative_path=self.data_dir
            )
            if not self.colourlaw.exists():
                err = f"`colourlaw` resolved to {self.colourlaw}, which does not exist."
                self._raise(err)
        return self

    @field_validator("cosmological_model", mode="after")
    @classmethod
    def validate_cosmological_model(cls, value: str) -> str:
        """Require ``value`` to name a built-in astropy cosmology realization.

        Returns:
            ``value`` unchanged.
        """
        if value not in cosmo.realizations.available:
            err = f"`cosmological_model` is {value} but must be one of {cosmo.realizations.available}"
            cls._raise(err)
        return value

    @field_validator("salt_model", mode="after")
    @classmethod
    def validate_salt_model(cls, value: str | Path) -> str | Path:
        """Require ``salt_model`` to look like a SALT2 or SALT3 model.

        The check is a plain substring test on the textual form of ``value``.
        It therefore accepts bare names such as ``"salt2"`` and any path whose
        text contains ``"salt2"`` or ``"salt3"``.

        Args:
            value: The field value. Because the field is typed ``str | Path``,
                this may be either a ``str`` or a ``Path``.

        Returns:
            ``value`` unchanged (its original type is preserved).
        """
        # `"salt2" in some_path` raises TypeError (Path is not a container).
        # Pydantic does not turn TypeError into a ValidationError, so a Path
        # input would crash here. Compare against the string form instead.
        text = str(value)
        if ("salt2" not in text) and ("salt3" not in text):
            err = f'`salt_model` is {value} but does not appear to be a salt2 or salt3 model, as it does not contain the string `"salt2"` or `"salt3"`'
            cls._raise(err)
        return value

    @model_validator(mode="after")
    def validate_salt_model_path(self) -> Self:
        """Replace ``salt_model`` with a resolved ``Path`` when one exists on disk.

        ``salt_model`` is resolved relative to ``paths.base``. If nothing exists
        there, the value is left untouched, presumably so it is treated as a
        model name downstream -- confirm.

        Returns:
            ``self``.
        """
        salt_path = self.paths.resolve_path(
            Path(self.salt_model), relative_path=self.paths.base
        )
        if salt_path.exists():
            self.salt_model = salt_path
        return self

    @model_validator(mode="after")
    def validate_max_phase(self) -> Self:
        """Require ``max_phase`` to be strictly greater than ``min_phase``.

        Returns:
            ``self``.
        """
        if self.max_phase <= self.min_phase:
            err = f"`max_phase`: {self.max_phase} is not strictly greater than `min_phase`: {self.min_phase}"
            self._raise(err)
        return self
# Import-time side effect: register `DataStepConfig` through the
# `register_step()` classmethod inherited from `StepConfig`. Presumably this
# makes the step discoverable by its `id` ("data") -- confirm against StepConfig.
DataStepConfig.register_step()
|