from .utils import *
from .moments import torchdim_moments_mixin, named_moments_mixin
class Marginals:
    """
    alan.Marginals()

    Constructed by calling :func:`Sample.marginals <alan.Sample.marginals>`.
    Represents the pre-computed marginals over the K samples for each latent
    variable.
    """

    def __init__(
            self,
            samples: dict[str, Tensor],
            weights: dict[frozenset, Tensor],
            all_platedims: dict[str, Dim],
            varname2groupvarname: dict[str, tuple[str]]):
        """
        ``samples`` and ``weights`` are flat dicts of torchdim Tensors, but
        their keys differ:

        * ``samples`` is keyed by a single varname.
        * ``weights`` is keyed by ``frozenset[groupvarname]`` (a frozenset so
          the key is hashable and insensitive to ordering).

        Weights are keyed by group because they depend on the K-dimensions,
        not on the underlying variables. Keying by a *set* of groups also
        allows storing joint marginals over several K-dimensions, not just one.

        Args:
            samples: varname -> sample Tensor.
            weights: frozenset of groupvarnames -> marginal weight Tensor.
            all_platedims: name -> plate Dim; used by :meth:`ess` to tell
                plate dims apart from K-dims.
            varname2groupvarname: varname -> the groupvarname used to build
                ``weights`` keys.
        """
        self.samples = samples
        self.weights = weights
        self.all_platedims = all_platedims
        self.varname2groupvarname = varname2groupvarname

    def _moments_uniform_input(self, moms):
        """
        Compute moments from the stored marginals.

        Args:
            moms: list of ``(varnames, m)`` pairs. ``varnames`` is a
                re-iterable collection of variable names; ``m`` is an object
                exposing ``from_marginals(samples, weights, all_platedims)``.

        Returns:
            list holding ``m.from_marginals(...)`` for each entry of ``moms``,
            in the same order.

        Raises:
            KeyError: if a varname is unknown, or if no weights are stored for
                the set of groups the requested variables belong to.
        """
        assert isinstance(moms, list)
        result = []
        for varnames, m in moms:
            samples = tuple(self.samples[varname] for varname in varnames)
            # Weights are keyed by the (unordered) set of groups, not by varname.
            groupvarnames = frozenset(
                self.varname2groupvarname[varname] for varname in varnames
            )
            try:
                weights = self.weights[groupvarnames]
            except KeyError:
                raise KeyError(
                    f"No marginal weights stored for variables {tuple(varnames)!r} "
                    f"(groups {set(groupvarnames)!r})"
                ) from None
            result.append(m.from_marginals(samples, weights, self.all_platedims))
        return result

    # Moment-computation entry points, implemented in the .moments module.
    _moments = torchdim_moments_mixin
    moments = named_moments_mixin

    def ess(self):
        """
        Effective sample size for every entry of ``self.weights``.

        For each weight tensor ``w``, sums ``w**2`` over its K-dimensions and
        returns the reciprocal. A K-dimension here means any dim of ``w`` that
        is not a plate dim. This is Kish's ESS, provided ``w`` is normalised
        over those K-dims; that normalisation is assumed and not checked here.

        Returns:
            dict with the same keys as ``self.weights`` (frozensets of
            groupvarnames), mapping to ESS tensors that keep any plate dims.
        """
        result = {}
        set_all_platedims = set(self.all_platedims.values())
        for groupvarnames, w in self.weights.items():
            # Treat every non-plate dim on the weight tensor as a K-dim.
            Kdims = tuple(set(generic_dims(w)).difference(set_all_platedims))
            assert 1 <= len(Kdims)
            result[groupvarnames] = 1 / (w**2).sum(Kdims)
        return result

    def min_ess(self):
        """
        Smallest effective sample size across all stored weights.

        Each ESS tensor from :meth:`ess` is reduced with ``generic_min``,
        presumably over any remaining plate dims (TODO confirm). The minimum of
        those per-key values is returned.

        Raises:
            ValueError: if ``self.weights`` is empty, so no ESS can be computed.
        """
        per_key_min = [generic_min(ess_value) for ess_value in self.ess().values()]
        if not per_key_min:
            raise ValueError("Marginals.min_ess: no weights stored, so ESS is undefined")
        return min(per_key_min)