Skip to content

supaernova.analysis.distribution

[docs] module supaernova.analysis.distribution

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from typing import TYPE_CHECKING, Any
from pathlib import Path

import numpy as np
import pandas as pd

from .analysis import Plotter, AbstractPlot

if TYPE_CHECKING:
    from supaernova.configs.steps.steps import AbstractStepResult
    from supaernova.configs.steps.posterior import (
        PosteriorStepResult,
    )

    from .analysis import Axis, Figure


class DistributionPlot(AbstractPlot):
    """Configuration for a distribution (corner) plot drawn by DistributionPlotter."""

    # Column labels for the plotted quantities. Keys select the quantity
    # (an attribute name for step results, a column index for 2-D arrays) and
    # values name the resulting column. When the data handed to
    # DistributionPlotter.plot_corner is a dict of chains, this is instead keyed
    # by the chain names, each mapping to that chain's own {key: label} dict.
    # None is treated as "no columns" by the prep helpers.
    labels: "dict[str | int, str | dict[str | int, str]] | None" = None
    # If True, plot_corner averages the inputs element-wise (np.mean over axis 0)
    # and draws a single chain named "mean" instead of one chain per input.
    mean: bool = False


class DistributionPlotter(Plotter):
    """Turn step results / arrays into DataFrames and draw corner plots.

    Every frame's columns come from a ``{key: label}`` mapping: each key selects
    one quantity (an attribute name for step results, a column index for
    arrays) and the label names the resulting column.
    """

    @staticmethod
    def _frame_from_result(
        data: "AbstractStepResult",
        labels: "dict[str | int, Any] | None",
    ) -> pd.DataFrame:
        """Build a frame with one column ``getattr(data, key)`` per ``key -> label``."""
        return pd.DataFrame({
            label: getattr(data, key) for (key, label) in (labels or {}).items()
        })

    @staticmethod
    def _frame_from_array(
        data: "np.ndarray",
        labels: "dict[str | int, Any] | None",
    ) -> pd.DataFrame:
        """Build a frame with one column ``data[:, index]`` per ``index -> label``.

        ``data`` must be at least 2-D since columns are taken with ``[:, index]``.
        """
        return pd.DataFrame({
            label: data[:, ind] for (ind, label) in (labels or {}).items()
        })

    @staticmethod
    def prep_from_result(
        data: "AbstractStepResult", config: DistributionPlot
    ) -> pd.DataFrame:
        """Build a frame from attributes of ``data`` named by ``config.labels`` keys.

        Returns an empty frame when ``config.labels`` is None.
        """
        return DistributionPlotter._frame_from_result(data, config.labels)

    @staticmethod
    def prep_from_array(data: "np.ndarray", config: DistributionPlot) -> pd.DataFrame:
        """Build a frame from columns of ``data`` indexed by ``config.labels`` keys.

        Returns an empty frame when ``config.labels`` is None.
        """
        return DistributionPlotter._frame_from_array(data, config.labels)

    @staticmethod
    def plot_corner(
        data: "AbstractStepResult | np.ndarray | list[AbstractStepResult] | list[np.ndarray] | dict[str, Any]",
        config: "DistributionPlot",
        *,
        fig: "Figure | None" = None,
        ax: "Axis | None" = None,
        force: bool = False,
        save: bool = True,
        **chain_kwargs: Any,
    ) -> tuple["Figure", "Axis"] | None:
        """Draw a corner plot of one or more distributions.

        Args:
            data: A step result or 2-D array, a list of those, or a dict mapping
                chain names to them. For a dict, ``config.labels`` must be keyed
                by the same chain names, each mapping to that chain's
                ``{key: label}`` dict. ``config`` itself is never modified.
            config: Plot configuration. ``savepath`` (default: current
                directory), ``name`` and ``ext`` form the output path;
                ``labels`` selects/names the columns; ``mean`` averages all
                inputs element-wise into a single chain named ``"mean"`` (the
                inputs should then be equally-shaped arrays).
            fig, ax: Forwarded to ``Plotter.corner``.
            force: Plot even if the output file already exists.
            save: Save to the output path and close the figure instead of
                returning it.
            **chain_kwargs: Forwarded to ``Plotter.corner`` as ``chain_kwargs``.

        Returns:
            ``(fig, ax)`` when ``save`` is False; otherwise None. Also None when
            the output file exists and ``force`` is False. NOTE(review): that
            existence check applies even when ``save`` is False.
        """
        savepath = (config.savepath or Path()) / f"{config.name}.{config.ext}"
        # Skip the (potentially expensive) plot if it has already been produced.
        if savepath.exists() and not force:
            return None

        names = None
        if isinstance(data, dict):
            names = list(data.keys())
            data = list(data.values())

        if not isinstance(data, list):
            data = [data]

        if config.mean:
            chains = {
                "mean": DistributionPlotter.prep_from_array(
                    np.mean(data, axis=0), config
                )
            }
        else:
            if names is None:
                # Unnamed inputs share the top-level labels and are keyed 0..n-1.
                names = list(range(len(data)))
                per_chain_labels = [config.labels] * len(data)
            else:
                # Named inputs each have their own labels dict. Resolve it locally
                # rather than reassigning config.labels, which would leak the last
                # chain's labels back into the caller's config.
                all_labels = config.labels or {}
                per_chain_labels = [all_labels[name] for name in names]

            chains = {}
            for name, d, labels in zip(names, data, per_chain_labels):
                if isinstance(d, np.ndarray):
                    chains[name] = DistributionPlotter._frame_from_array(d, labels)
                else:
                    chains[name] = DistributionPlotter._frame_from_result(d, labels)

        fig, ax = Plotter.corner(chains, fig=fig, ax=ax, chain_kwargs=chain_kwargs)

        if save:
            fig = Plotter.save(fig, savepath)
            Plotter.close(fig, ax)
            return None
        return fig, ax