Skip to content

supaernova.steps.posterior.tf.hmc

[docs] module supaernova.steps.posterior.tf.hmc

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Copyright 2025 Patrick Armstrong
import os

# TensorFlow / Keras runtime configuration. These are assigned before
# `import tensorflow` (below) so they are already present in the process
# environment when TensorFlow and Keras are loaded; keep this ordering.
# Make `tf.keras` resolve to legacy Keras 2 (the `tf_keras` package)
# instead of Keras 3.
os.environ["TF_USE_LEGACY_KERAS"] = "1"
# Select the TensorFlow backend for Keras.
os.environ["KERAS_BACKEND"] = "tensorflow"
# Request deterministic op implementations for reproducible runs.
# NOTE(review): `tf.config.experimental.enable_op_determinism()` is the
# API-level equivalent in newer TF releases — confirm which the project relies on.
os.environ["TF_DETERMINISTIC_OPS"] = "1"
# Turn off oneDNN custom ops, which can change floating-point results
# slightly through different computation orders.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
import tensorflow as tf


class PosteriorHMCValue(tf.Module):
    """Container for the outputs of an HMC posterior sampling run.

    Each value is stored as a plain attribute on a ``tf.Module``. Any
    ``tf.Variable`` among them can then be reached through the module's
    ``variables`` / ``trainable_variables`` properties and is tracked when the
    module is checkpointed or saved.

    The meanings below are inferred from the attribute names only and should
    be confirmed against the code that builds this object.

    Args:
        samples: Samples produced by the HMC sampler. Presumably the stacked
            chain states; TODO confirm shape.
        step_sizes_final: Step size(s) at the end of sampling. Presumably these
            are the values after any step-size adaptation; TODO confirm.
        is_accepted: Acceptance flags reported by the sampler.
        u_delta_av: Posterior value named ``u_delta_av``. Possibly an
            unconstrained/whitened counterpart of ``delta_av``; TODO confirm.
        u_latents: Posterior value named ``u_latents``. Possibly an
            unconstrained/whitened counterpart of ``z_latents``; TODO confirm.
        delta_av: Posterior value named ``delta_av``.
        z_latents: Posterior value named ``z_latents``.
        delta_m: Posterior value named ``delta_m``.
        delta_p: Posterior value named ``delta_p``.
    """

    def __init__(
        self,
        samples: tf.Variable,
        step_sizes_final: tf.Variable,
        is_accepted: tf.Variable,
        u_delta_av: tf.Variable,
        u_latents: tf.Variable,
        delta_av: tf.Variable,
        z_latents: tf.Variable,
        delta_m: tf.Variable,
        delta_p: tf.Variable,
    ) -> None:
        # Initialise tf.Module first. It assigns the module's name and name
        # scope; without this call, the public `.name` / `.name_scope`
        # properties raise AttributeError. The default name is derived from
        # the class name.
        super().__init__()

        # Raw sampler output and diagnostics.
        self.samples: tf.Variable = samples
        self.step_sizes_final: tf.Variable = step_sizes_final
        self.is_accepted: tf.Variable = is_accepted

        # Individual named posterior quantities.
        self.u_delta_av: tf.Variable = u_delta_av
        self.u_latents: tf.Variable = u_latents
        self.delta_av: tf.Variable = delta_av
        self.z_latents: tf.Variable = z_latents
        self.delta_m: tf.Variable = delta_m
        self.delta_p: tf.Variable = delta_p