from typing import Optional
from .utils import *
from .Plate import flatten_tree, tensordict2tree
from .moments import torchdim_moments_mixin, named_moments_mixin
class AbstractImportanceSample:
    """
    Behaviour shared by the importance-sample classes.

    The methods here read ``self.samples_flatdict`` (a flat mapping from variable name to samples) and ``self.Ndim``; subclasses are responsible for setting both.
    """

    def dump(self):
        """
        Returns the importance samples as a dictionary of named tensors. The name ``N`` indexes the different samples, and the remaining names correspond to plates.

        Warning:
            You shouldn't really need to use this method. e.g. if you're trying to compute moments, use the ``moments`` method directly instead.
        """
        named_samples = dim2named_dict(self.samples_flatdict)
        return named_samples

    def _moments_uniform_input(self, moms):
        """
        Evaluates every requested moment on the stored samples.

        Arguments:
            moms (list):
                A list of ``(varnames, m)`` pairs, where ``varnames`` is an iterable of keys into ``self.samples_flatdict`` and ``m`` provides a ``from_samples(samples, Ndim)`` method.

        Returns:
            A list with one entry per pair in ``moms``, in the same order, holding the value returned by ``m.from_samples``.
        """
        assert isinstance(moms, list)
        flat_samples = self.samples_flatdict
        return [
            moment.from_samples(
                # from_samples receives the samples of all requested variables as one tuple.
                tuple(flat_samples[name] for name in names),
                self.Ndim,
            )
            for names, moment in moms
        ]

    # Moment-computing entry points are supplied by the mixins from .moments.
    _moments = torchdim_moments_mixin
    moments = named_moments_mixin
[docs]
class ImportanceSample(AbstractImportanceSample):
    """
    alan.ImportanceSample()
    Constructed by calling :func:`Sample.importance_sample <alan.Sample.importance_sample>`. Represents N joint samples in the latent space.
    """

    def __init__(self, problem, samples_tree, Ndim):
        """
        ``samples_tree`` is tree-structured torchdim (as we might need to use it for extended).
        A flattened view of the same samples is kept in ``self.samples_flatdict``.
        """
        self.problem = problem
        self.samples_tree = samples_tree
        self.samples_flatdict = flatten_tree(samples_tree)
        self.Ndim = Ndim

    def extend(self, extended_platesizes: dict[str, int], extended_inputs=None):
        """
        Does prediction by:

        * taking a posterior sample, represented by the ImportanceSample object.
        * extending the plate sizes.
        * sampling the extra latent variables from the prior.

        It returns an :class:`.ExtendedImportanceSample` object.

        Arguments:
            extended_platesizes (dict[str, int]):
                A dictionary mapping the platename to the extended platesize. Must be the same as or bigger than the platesizes in the underlying model. Plates that are left out keep their current size. The dictionary passed in is not modified.
            extended_inputs (dict[str, torch.Tensor]):
                If the model has any e.g. features given as ``inputs`` to :class:`.BoundPlate`, then the extended versions of these inputs must be provided here.

        Raises:
            AssertionError: If either argument is not a dict, or if ``extended_platesizes`` names a plate that is not in ``self.problem.all_platedims``.

        Note:
            Won't work if P has any plated parameters, as these won't be extended.
        """
        assert isinstance(extended_platesizes, dict)
        if extended_inputs is None:
            extended_inputs = {}
        assert isinstance(extended_inputs, dict)

        original_platedims = self.problem.all_platedims

        # Work on a copy so that filling in the missing plates below does not
        # modify the dictionary owned by the caller.
        extended_platesizes = dict(extended_platesizes)

        # Any plate the caller did not mention keeps its current size.
        for name, dim in original_platedims.items():
            extended_platesizes.setdefault(name, dim.size)

        # Every original plate is now present, so the key sets can only differ
        # if the caller supplied a plate the model does not have.
        extra_plates = set(extended_platesizes) - set(original_platedims)
        assert not extra_plates, (
            f"extended_platesizes contains plates that are not in the model: {sorted(extra_plates)}"
        )

        # Create the new platedims from the platesizes.
        extended_platedims = {name: Dim(name, size) for name, size in extended_platesizes.items()}

        # Will need to add the extended inputs to the scope
        all_inputs_params = tensordict2tree(
            self.problem.P.plate,
            named2dim_dict(extended_inputs, extended_platedims),
        )

        extended_sample = self.problem.P.plate.sample_extended(
            sample=self.samples_tree,
            name=None,
            scope={},
            inputs_params=all_inputs_params,
            original_platedims=original_platedims,
            extended_platedims=extended_platedims,
            active_extended_platedims=[],
            Ndim=self.Ndim,
            reparam=False,
            original_data=self.problem.data,
        )

        return ExtendedImportanceSample(
            self.problem,
            extended_sample,
            self.Ndim,
            extended_platedims,
            extended_inputs,
        )
[docs]
class ExtendedImportanceSample(AbstractImportanceSample):
    """
    alan.ExtendedImportanceSample()
    Constructed by calling :func:`ImportanceSample.extend <alan.ImportanceSample.extend>`. Represents N samples from the posterior over all latent variables, that has subsequently been extended.
    """

    def __init__(self, problem, samples_tree, Ndim, extended_platedims, extended_inputs):
        """
        ``samples_tree`` is tree-structured torchdim (as we might need to use it for extended).
        ``extended_platedims`` and ``extended_inputs`` are stored for later use by :meth:`predictive_ll`.
        """
        self.problem = problem
        self.samples_tree = samples_tree
        self.samples_flatdict = flatten_tree(samples_tree)
        self.Ndim = Ndim
        self.extended_platedims = extended_platedims
        self.extended_inputs = extended_inputs

    def predictive_ll(self, data: dict[str, Tensor]):
        """
        Computes the average predictive log-likelihood for extended data.

        Arguments:
            data (dict[str, torch.Tensor]):
                Extended data, provided as a dictionary mapping the variable name to a torch.Tensor. Note that this must be all data: both test and train. Every key must be the name of a data variable in the original problem. The dictionary passed in is not modified.

        Returns:
            A dictionary mapping each variable name returned by the plate's ``predictive_ll`` to ``logmeanexp_dims`` over ``Ndim`` of the (all-data minus training-data) log-likelihood, summed over any other dims.

        Raises:
            AssertionError: If ``data`` is not a dict, if it contains a name that is not in the original data, or if the training and all-data log-likelihoods returned by the plate do not line up.
        """
        assert isinstance(data, dict)

        # Convert data to torchdim
        extended_data = named2dim_tensordict(self.extended_platedims, data)
        original_data = flatten_tree(self.problem.data)

        # Check that data contains no extra data names. Only the names are
        # compared; the caller's dict is left untouched so that it can safely
        # be reused for another call.
        unknown_names = set(data) - set(original_data)
        assert not unknown_names, (
            f"data contains variables that are not in the original data: {sorted(unknown_names)}"
        )
        # NOTE(review): variables that are absent from `data` are not added to
        # `extended_data`, so the plate only receives the variables the caller
        # supplied. Confirm whether the original tensors should be passed on
        # for variables that are not being extended.

        # Will need to add the extended inputs to the scope
        all_inputs_params = tensordict2tree(
            self.problem.P.plate,
            named2dim_dict(self.extended_inputs, self.extended_platedims),
        )

        lls_train, lls_all = self.problem.P.plate.predictive_ll(
            sample=self.samples_tree,
            name=None,
            scope={},
            inputs_params=all_inputs_params,
            original_platedims=self.problem.all_platedims,
            extended_platedims=self.extended_platedims,
            original_data=original_data,
            extended_data=extended_data,
        )

        # If we have lls for a variable in the training data, we should also have lls
        # for it in the all (training+test) data.
        assert set(lls_all.keys()) == set(lls_train.keys())

        result = {}
        for varname, ll_all in lls_all.items():
            ll_train = lls_train[varname]

            # Every dim except the sample dim is summed out below.
            dims_all = [dim for dim in ll_all.dims if dim is not self.Ndim]
            dims_train = [dim for dim in ll_train.dims if dim is not self.Ndim]
            assert len(dims_all) == len(dims_train)

            if dims_all:
                # Sum over plates
                ll_all = ll_all.sum(dims_all)
                ll_train = ll_train.sum(dims_train)

            # Take mean over Ndim
            result[varname] = logmeanexp_dims(ll_all - ll_train, (self.Ndim,))

        return result