Skip to content

supaernova.steps.pae.tf.loss

[docs] module supaernova.steps.pae.tf.loss

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
import os
from typing import TYPE_CHECKING

# These environment variables are assigned *before* `import tensorflow` below,
# so they are already set when TensorFlow/Keras is first loaded in this
# process (if TF was imported earlier elsewhere, they may have no effect).
# Presumed intent -- confirm against the TensorFlow/Keras docs for the pinned
# versions:
#   TF_USE_LEGACY_KERAS=1     -> make tf.keras resolve to legacy Keras 2 (tf_keras)
#   KERAS_BACKEND=tensorflow  -> pin the Keras backend to TensorFlow
#   TF_DETERMINISTIC_OPS=1    -> request deterministic op implementations
#   TF_ENABLE_ONEDNN_OPTS=0   -> disable oneDNN optimisations, presumably to
#                                avoid run-to-run floating-point differences
# NOTE: importing this module mutates os.environ for the whole process.
# Keep these lines above the tensorflow import; reordering may silently
# change behaviour.
os.environ["TF_USE_LEGACY_KERAS"] = "1"
os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ["TF_DETERMINISTIC_OPS"] = "1"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
import tensorflow as tf

# TFPAEModel is needed only for the string type annotation on WHuber's `model`
# parameter, so it is imported only under static type checking (presumably to
# avoid a runtime/circular import of `.tf`).
if TYPE_CHECKING:
    from .tf import TFPAEModel


def WHuber(y_true, y_pred, *, model: "TFPAEModel"):
    """Masked, amplitude-scaled Huber loss.

    The residual ``y_true - y_pred`` is multiplied by ``model._loss.input_mask``
    and then divided by ``model._loss.input_d_amp``. Each element of that
    scaled residual ``r`` is penalised as follows, with
    ``delta = model.loss_clip_delta``:

    * ``0.5 * r**2`` when ``|r| < delta`` (quadratic region);
    * ``delta * (|r| - 0.5 * delta)`` otherwise (linear region).

    Elements where the mask is 0 give ``r == 0`` and so contribute nothing,
    provided ``input_d_amp`` is non-zero there. The penalties are summed over
    the last two axes, and the mean of those sums is returned.

    Args:
        y_true: Target tensor of rank >= 2; the last two axes are summed.
        y_pred: Prediction tensor, broadcast-compatible with ``y_true``.
        model: Supplies ``_loss.input_mask``, ``_loss.input_d_amp`` (presumably
            a per-element scale or uncertainty -- confirm) and the Huber
            threshold ``loss_clip_delta``.

    Returns:
        A scalar tensor: ``tf.reduce_mean`` over all remaining axes.
    """
    mask = model._loss.input_mask
    d_amp = model._loss.input_d_amp
    delta = model.loss_clip_delta

    # Mask first, then scale by input_d_amp, before applying the threshold.
    # NOTE(review): where input_d_amp == 0 this division yields inf, or NaN
    # when the masked residual is also 0; confirm d_amp is strictly positive.
    scaled = mask * (y_true - y_pred) / d_amp
    magnitude = tf.abs(scaled)

    penalty = tf.where(
        magnitude < delta,
        0.5 * tf.square(scaled),            # quadratic region
        delta * (magnitude - 0.5 * delta),  # linear region
    )
    summed = tf.reduce_sum(penalty, axis=(-2, -1))
    return tf.reduce_mean(summed)