Skip to content

supaernova.configs.steps.pae.tf

[docs] module supaernova.configs.steps.pae.tf

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import os
from typing import Any, Concatenate, cast, override
from functools import cached_property
from collections.abc import Callable

from pydantic import PositiveFloat, computed_field

# TensorFlow/Keras runtime switches. They are assigned here, ahead of the
# ``import tensorflow`` statement that follows, so they are already present in
# the process environment when that import runs. Keep this ordering.
# NOTE(review): going by the variable names these presumably select the legacy
# Keras implementation, the TensorFlow Keras backend, deterministic ops, and
# disable oneDNN optimisations respectively -- confirm against the
# TensorFlow/Keras documentation for the pinned versions.
os.environ["TF_USE_LEGACY_KERAS"] = "1"
os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ["TF_DETERMINISTIC_OPS"] = "1"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
import tensorflow as tf
from tensorflow import keras as ks

from supaernova.steps.pae.tf import (
    loss as snpae_losses,
)
from supaernova.configs.steps import ConfigInputObject, validate_object

from .model import PAEModelConfig

# Type aliases for the objects a configuration entry may resolve to.

# An activation: a callable taking one tensor and returning a tensor.
ActivationObject = Callable[[tf.Tensor], tf.Tensor]
# A kernel regulariser: a Keras ``Regularizer`` class, or a tensor -> tensor
# callable.
RegulariserObject = type[ks.regularizers.Regularizer] | Callable[[tf.Tensor], tf.Tensor]
# A learning-rate scheduler: a Keras ``LearningRateSchedule`` class, or a
# callable whose leading argument is the step (int or tensor) and which
# returns a tensor.
# NOTE(review): ``Concatenate[...]`` is wrapped in an extra list here; the
# usual spelling is ``Callable[Concatenate[X, ...], R]``. Confirm whether the
# extra brackets are intentional before changing them.
SchedulerObject = (
    type[ks.optimizers.schedules.LearningRateSchedule]
    | Callable[[Concatenate[int | tf.Tensor, ...]], tf.Tensor]
)
# An optimiser: a Keras ``Optimizer`` class.
OptimiserObject = type[ks.optimizers.Optimizer]
# A loss: a Keras ``Loss`` class, or a callable taking two tensors and
# returning a tensor.
LossObject = type[ks.losses.Loss] | Callable[[tf.Tensor, tf.Tensor], tf.Tensor]


def validate_activation(activation: ConfigInputObject[ActivationObject]):
    """Validate an activation configuration entry.

    Delegates to ``validate_object``, passing ``tf.nn`` as the module and
    ``tf.nn.relu`` as the dummy object.

    Args:
        activation: The raw activation entry from the configuration.

    Returns:
        Whatever ``validate_object`` returns for this entry.
    """
    nn_module = tf.nn
    validated = validate_object(activation, mod=nn_module, dummy_obj=nn_module.relu)
    return validated


def validate_kernel_regulariser(
    kernel_regulariser: ConfigInputObject[RegulariserObject],
) -> RegulariserObject:
    """Validate a kernel regulariser configuration entry.

    Delegates to ``validate_object``, passing ``ks.regularizers`` as the module
    and ``ks.regularizers.Regularizer`` as the dummy object.

    Args:
        kernel_regulariser: The raw regulariser entry from the configuration.

    Returns:
        The validated regulariser object.
    """
    regularisers = ks.regularizers
    validated = validate_object(
        kernel_regulariser,
        mod=regularisers,
        dummy_obj=regularisers.Regularizer,
    )
    return validated


def validate_scheduler(
    scheduler: ConfigInputObject[SchedulerObject],
) -> SchedulerObject:
    """Validate a learning-rate scheduler configuration entry.

    Delegates to ``validate_object``, passing ``ks.optimizers.schedules`` as
    the module and its ``LearningRateSchedule`` class as the dummy object.

    Args:
        scheduler: The raw scheduler entry from the configuration.

    Returns:
        The validated scheduler object.
    """
    schedules = ks.optimizers.schedules
    validated = validate_object(
        scheduler, mod=schedules, dummy_obj=schedules.LearningRateSchedule
    )
    return validated


def validate_optimiser(
    optimiser: ConfigInputObject[OptimiserObject],
):
    """Validate an optimiser configuration entry.

    Delegates to ``validate_object``, passing ``ks.optimizers`` as the module
    and ``ks.optimizers.Optimizer`` as the dummy object.

    Args:
        optimiser: The raw optimiser entry from the configuration.

    Returns:
        Whatever ``validate_object`` returns for this entry.
    """
    optimisers = ks.optimizers
    validated = validate_object(
        optimiser, mod=optimisers, dummy_obj=optimisers.Optimizer
    )
    return validated


def validate_loss(
    loss: ConfigInputObject[LossObject],
):
    """Validate a loss configuration entry.

    Tries ``validate_object`` with every pairing of dummy object
    (``ks.losses.Loss``, then ``ks.losses.mae``) and module (``ks.losses``,
    then the project's ``snpae_losses``), and returns the first result that
    does not raise ``ValueError``.

    Args:
        loss: The raw loss entry from the configuration.

    Returns:
        The first successfully validated loss object.

    Raises:
        ValueError: If every pairing fails; the message lists each individual
            failure on its own line.
    """
    # Built up front, exactly as before, so the entry is formatted once.
    header = f"Could not validate loss: {loss}:\n"
    failures: list[str] = []

    # Dummy objects form the outer loop and modules the inner loop.
    attempts = (
        (template, module)
        for template in (ks.losses.Loss, ks.losses.mae)
        for module in (ks.losses, snpae_losses)
    )
    for template, module in attempts:
        try:
            return validate_object(loss, dummy_obj=template, mod=module)
        except ValueError as exc:
            failures.append(f"{exc}\n")

    raise ValueError(header + "".join(failures))


def get_loss(
    loss_fn: Callable[[tf.Tensor, tf.Tensor], tf.Tensor],
) -> type[ks.losses.Loss]:
    """Build a Keras ``Loss`` subclass that delegates to ``loss_fn``.

    The generated class is registered with Keras' serialisation registry under
    the ``"SuPAErnova"`` package.

    Args:
        loss_fn: Callable invoked as ``loss_fn(y_true, y_pred, model=...)``.

    Returns:
        The registered ``CustomLoss`` class (not an instance).
    """
    # NOTE(review): every call registers a new class under the same
    # package/name pair -- confirm how the installed Keras version treats
    # repeated registration.
    register = ks.utils.register_keras_serializable("SuPAErnova")

    class CustomLoss(ks.losses.Loss):
        @override
        def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
            # ``self.model`` is not assigned anywhere in this block; presumably
            # the training code attaches it to the instance before the loss is
            # evaluated -- TODO confirm.
            owning_model = self.model
            return loss_fn(y_true, y_pred, model=owning_model)

    return register(CustomLoss)


class TFPAEModelConfig(PAEModelConfig):
    """TensorFlow-specific settings layered on top of ``PAEModelConfig``.

    Each raw configuration entry (activation, kernel regulariser, scheduler,
    optimiser, loss) has a matching computed field that validates the entry on
    first access and caches the result.
    """

    # --- Training ---
    activation: ConfigInputObject[ActivationObject]

    @computed_field
    @cached_property
    def activation_fn(self) -> ActivationObject:
        """Return the validated activation for ``activation``."""
        return validate_activation(self.activation)

    kernel_regulariser: ConfigInputObject[RegulariserObject] | None = None
    kernel_regulariser_penalty: PositiveFloat | None = None

    @computed_field
    @cached_property
    def kernel_regulariser_cls(self) -> type[ks.regularizers.Regularizer] | None:
        """Return a ``Regularizer`` class for ``kernel_regulariser``.

        Returns:
            ``None`` when no regulariser is configured; the validated object
            itself when it is already a class; otherwise a ``Regularizer``
            subclass whose ``__call__`` forwards to the validated callable.
        """
        if self.kernel_regulariser is None:
            return None
        regulariser = validate_kernel_regulariser(self.kernel_regulariser)
        if isinstance(regulariser, type):
            return regulariser

        class CustomRegulariser(ks.regularizers.Regularizer):
            # NOTE(review): constructor arguments are forwarded unchanged to the
            # base class. Confirm the base accepts whatever callers pass here
            # (e.g. a penalty value); the wrapped callable never sees them.
            @override
            def __init__(self, *args: Any, **kwargs: Any) -> None:
                super().__init__(*args, **kwargs)

            @override
            def __call__(self, x: tf.Tensor) -> tf.Tensor:
                return regulariser(x)

        return CustomRegulariser

    scheduler: ConfigInputObject[SchedulerObject]

    @computed_field
    @cached_property
    def scheduler_cls(self) -> type[ks.optimizers.schedules.LearningRateSchedule]:
        """Return a ``LearningRateSchedule`` class for ``scheduler``.

        Returns:
            The validated object itself when it is already a class; otherwise
            a ``LearningRateSchedule`` subclass that stores
            ``initial_learning_rate``, ``decay_steps`` and ``decay_rate`` and
            passes them as keyword arguments to the validated callable on
            every step.
        """
        scheduler = validate_scheduler(self.scheduler)
        if isinstance(scheduler, type):
            return scheduler

        class CustomScheduler(ks.optimizers.schedules.LearningRateSchedule):
            @override
            def __init__(
                self,
                *,
                initial_learning_rate: float,
                decay_steps: int,
                decay_rate: float,
            ) -> None:
                self.initial_learning_rate: float = initial_learning_rate
                self.decay_steps: int = decay_steps
                self.decay_rate: float = decay_rate

            @override
            def __call__(self, step: int | tf.Tensor) -> tf.Tensor:
                return scheduler(
                    step,
                    initial_learning_rate=self.initial_learning_rate,
                    decay_steps=self.decay_steps,
                    decay_rate=self.decay_rate,
                )

            @override
            def get_config(self) -> dict[str, Any]:
                # The base implementation raises NotImplementedError, which
                # breaks serialising any optimiser that uses this schedule.
                # The keys match the keyword-only constructor parameters so
                # the config can be fed straight back into ``__init__``.
                return {
                    "initial_learning_rate": self.initial_learning_rate,
                    "decay_steps": self.decay_steps,
                    "decay_rate": self.decay_rate,
                }

        return CustomScheduler

    optimiser: ConfigInputObject[OptimiserObject]

    @computed_field
    @cached_property
    def optimiser_cls(self) -> type[ks.optimizers.Optimizer]:
        """Return the validated optimiser class for ``optimiser``."""
        # Cast via ``object`` first; this only affects static type checking.
        return cast(
            "type[ks.optimizers.Optimizer]",
            cast("object", validate_optimiser(self.optimiser)),
        )

    loss: ConfigInputObject[LossObject]

    @computed_field
    @cached_property
    def loss_cls(self) -> type[ks.losses.Loss]:
        """Return a Keras ``Loss`` class wrapping the validated ``loss``.

        A validated class is instantiated with no arguments first; the
        resulting callable (or the validated function) is then wrapped by
        ``get_loss``.
        """
        loss = validate_loss(self.loss)

        if isinstance(loss, type):
            # NOTE(review): the wrapper calls this instance with a ``model=``
            # keyword -- confirm every configurable loss class accepts it.
            loss = loss()

        return get_loss(loss)