supaernova.steps.pae.tf.loss
[docs]
module
supaernova.steps.pae.tf.loss
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21 | import os
from typing import TYPE_CHECKING
# Configure TensorFlow/Keras through environment variables. This runs ahead of
# the `import tensorflow` statement that follows, so keep it there.
# NOTE(review): per the public TF/Keras docs these select the legacy Keras
# implementation, the TensorFlow Keras backend, deterministic op kernels, and
# turn off oneDNN optimizations -- confirm against the pinned TF version.
os.environ.update(
    {
        "TF_USE_LEGACY_KERAS": "1",
        "KERAS_BACKEND": "tensorflow",
        "TF_DETERMINISTIC_OPS": "1",
        "TF_ENABLE_ONEDNN_OPTS": "0",
    }
)
import tensorflow as tf
if TYPE_CHECKING:
from .tf import TFPAEModel
def WHuber(y_true, y_pred, *, model: "TFPAEModel"):
    """Return a masked, amplitude-weighted Huber loss as a single scalar.

    The residual ``y_true - y_pred`` is multiplied by
    ``model._loss.input_mask`` and divided by ``model._loss.input_d_amp``.
    Elements whose absolute scaled residual is below
    ``model.loss_clip_delta`` contribute ``0.5 * r**2``; all others contribute
    ``delta * (|r| - 0.5 * delta)``.

    Args:
        y_true: Target values.
        y_pred: Predicted values, broadcast-compatible with ``y_true``.
        model: Object supplying ``loss_clip_delta`` and the ``_loss``
            attributes ``input_mask`` and ``input_d_amp``.

    Returns:
        The per-element losses summed over the last two axes, then averaged
        over whatever axes remain.
    """
    # NOTE(review): the (-2, -1) reduction implies inputs of rank >= 2,
    # presumably (batch, ..., ...) -- confirm against the caller.
    # NOTE(review): input_d_amp is assumed non-zero; a zero entry would yield
    # NaN/inf here even where input_mask is zero -- confirm upstream.
    delta = model.loss_clip_delta
    loss_state = model._loss

    residual = loss_state.input_mask * (y_true - y_pred) / loss_state.input_d_amp
    abs_residual = tf.abs(residual)

    quadratic_term = 0.5 * tf.square(residual)
    linear_term = delta * (abs_residual - 0.5 * delta)

    # Quadratic inside the clip threshold, linear at or beyond it.
    per_element = tf.where(abs_residual < delta, quadratic_term, linear_term)
    per_sample = tf.reduce_sum(per_element, axis=(-2, -1))
    return tf.reduce_mean(per_sample)
|