from typing import Optional
import torch as t
import torch.nn as nn
from .utils import *
from .Sampler import Sampler, PermutationSampler
from .Plate import tensordict2tree, Plate, flatten_tree
from .Stores import BufferStore, ParameterStore, ModuleStore
from .moments import moments_func2name
from .Param import OptParam, QEMParam
from .conversions import conversion_dict
from .Timeseries import Timeseries
def named2torchdim_flat2tree(flat_named:dict, all_platedims, plate):
    """
    Converts a flat dict of named tensors into a tree of torchdim tensors.

    The named dimensions are first converted to torchdim Dims using `all_platedims`
    (via `named2dim_dict`), and the resulting flat dict is then nested according
    to the structure of `plate` (via `tensordict2tree`).
    """
    return tensordict2tree(plate, named2dim_dict(flat_named, all_platedims))
def expand_named(x, names:list[str], all_platesizes:dict[str, int]):
    """
    Expands the (possibly partially named) tensor `x` so that it has a named dimension for every plate in `names`.

    Plates in `names` that `x` doesn't already have are added as new leading dimensions,
    in the order they appear in `names`, with sizes taken from `all_platesizes`.
    Dimensions that `x` already has keep their names and positions. The result is made
    contiguous, so the expanded plate dimensions are materialized in memory rather than
    being stride-0 broadcasts.

    Arguments:
        x (torch.Tensor): the initial value, e.g. an unnamed scalar or a named tensor.
        names (list[str]): the plates the result should have.
        all_platesizes (dict[str, int]): mapping from platename to platesize.

    Raises:
        Exception: if a dimension name on `x`, or in `names`, is missing from `all_platesizes`,
        or if a named dimension on `x` has a size that disagrees with `all_platesizes`.
    """
    #Names of the dimensions that x already has (unnamed dimensions are skipped).
    names_x = [name for name in x.names if name is not None]
    for name_x in names_x:
        if name_x not in all_platesizes:
            raise Exception(f"{name_x} is specified on a parameter, but is not given in all_platesizes")
        if x.size(name_x) != all_platesizes[name_x]:
            raise Exception(f"{name_x} is given as length {all_platesizes[name_x]} on all_platesizes, but there's a parameter where this dimension is size {x.size(name_x)}")
    for name in names:
        if name not in all_platesizes:
            raise Exception(f"{name} is a plate dimension, but is not given in all_platesizes")
    #Keep the order given in `names`: a set difference would make the order of the new
    #dimensions arbitrary, so names could be attached to dimensions of the wrong size.
    extra_platenames = [name for name in names if name not in names_x]
    extra_plate_shape = [all_platesizes[name] for name in extra_platenames]
    #Only the newly added leading dimensions need names; existing dimensions keep theirs.
    return x.expand(*extra_plate_shape, *x.shape).contiguous().refine_names(*extra_platenames, *x.names)
def non_none_names(x):
    """
    Returns the dimension names of the named tensor `x` as a list, skipping unnamed (None) dimensions.

    The order follows `x.names`.
    """
    return list(filter(lambda dimname: dimname is not None, x.names))
class BoundPlate(nn.Module):
    """
    Binds a Plate representing P or Q to platesizes, and initializes parameters specified by OptParam or QEMParam.

    Arguments:
        plate (Plate):
            The plate specifying P or Q.
        all_platesizes (dict[str, int]):
            Dictionary mapping string platename to integer platesize.

    Keyword Arguments:
        inputs (dict[str, (named) torch.Tensor]):
            Dictionary mapping string input name to input value, as a named ``torch.Tensor``. This is used to represent e.g. features that the model is conditioned on, but that aren't sampled from the model.
        extra_opt_params (dict[str, (named) torch.Tensor]):
            Dictionary mapping string parameter name to initial parameter value, as a named ``torch.Tensor``. Usually you'd specify parameters to be optimized using OptParam. But the OptParam approach is slightly restrictive, as an OptParam can only be used as a direct argument to a distribution (e.g. ``a = Normal(OptParam(0.), 1.)``), whereas a parameter given here can be used anywhere in the program. The dict passed in is not modified.

    Inputs and extra_opt_params are specified as named tensors, where the names correspond to the plates (as with data).
    """
    def __init__(self, plate: Plate, all_platesizes:dict[str, int], inputs=None, extra_opt_params=None):
        super().__init__()
        #A tensor that e.g. moves to GPU when we call `problem.to(device='cuda')`.
        self.register_buffer("_device_tensor", t.zeros(()))

        assert isinstance(plate, Plate)
        self.plate = plate

        if all_platesizes is None:
            all_platesizes = {}
        assert isinstance(all_platesizes, dict)
        for platename in plate.all_platenames():
            if platename not in all_platesizes:
                raise Exception(f"Every plate must have a platesize specified in all_platesizes, but {platename} doesn't have a specified size")
        self.all_platesizes = all_platesizes

        if inputs is None:
            inputs = {}
        assert isinstance(inputs, dict)

        if extra_opt_params is None:
            extra_opt_params = {}
        assert isinstance(extra_opt_params, dict)

        #Check inputs are all tensors
        inputs_extra_opt_params = {**inputs, **extra_opt_params}
        for k, v in inputs_extra_opt_params.items():
            if not isinstance(v, t.Tensor):
                raise Exception(f"`inputs` and `extra_opt_params` must be provided as a plain named tensor, but {k} is of type {type(v)}")

        #Check all dimensions used in inputs/extra_opt_params are present in all_platesizes, and match
        for k, v in inputs_extra_opt_params.items():
            for name in non_none_names(v):
                if name not in all_platesizes:
                    raise Exception(f"Dimension name {name} used on input/extra_opt_param {k}, but not provided in all_platesizes")
                if v.size(name) != all_platesizes[name]:
                    raise Exception(f"Dimension mismatch for input {k} along dimension {name}; all_platesizes gives {all_platesizes[name]}, while {k} is {v.size(name)}")

        #Check that timeseries inits are in the right place
        check_timeseries(plate)

        #Check that every input/extra_opt_param used as a distribution argument only has
        #plates that the variable using it also has.
        groupvarname2platenames = self.plate.groupvarname2platenames()
        varname2groupvarname_dist = self.plate.varname2groupvarname_dist()
        for varname, (groupvarname, dist) in varname2groupvarname_dist.items():
            for argname in dist.all_args:
                if argname in inputs_extra_opt_params:
                    dist_platenames = groupvarname2platenames[groupvarname]
                    input_extra_opt_param_platenames = non_none_names(inputs_extra_opt_params[argname])
                    if not set(input_extra_opt_param_platenames).issubset(dist_platenames):
                        raise Exception(f"{argname} is used on {varname}, which has plates {dist_platenames}, but {argname} has plates {input_extra_opt_param_platenames}")

        ###################################
        #### Setting up QEM/Opt Params ####
        ###################################

        #Copy, so that registering OptParams below doesn't mutate the caller's extra_opt_params dict.
        opt_paramname2tensor = dict(extra_opt_params)
        #extra_opt_params are used untransformed; OptParams register their own trans below.
        self.opt_paramname2trans = {paramname: (lambda x: x) for paramname in opt_paramname2tensor}

        #List of varname
        self.qem_list_varname = []
        #List of conversion
        self.qem_list_conversion = []
        #List of lists of rmkeys; outer list corresponds to random variables.
        self.qem_list_rmkeys = []
        #Flat list of rmkeys.
        self.qem_flat_list_rmkeys = []
        #Dict mapping meanname -> moving average moment as a tensor.
        qem_meanname2mom = {}
        #Dict mapping paramname -> conventional parameter as a tensor
        qem_params = {}
        #Dict mapping (varname, distargname) -> paramname
        self.qem_varname_distargname2paramname = {}
        #We need meanname because we need to put the tensors in a BufferStore, which requires a dict as input.
        self.qem_rmkey2meanname = {}
        self.qem_meanname2rmkey = {}

        for varname, (groupvarname, dist) in varname2groupvarname_dist.items():
            platenames = groupvarname2platenames[groupvarname]
            if not dist.qem_dist:
                #Not a QEM distribution, so may contain OptParams.
                for paramname, (distargname, param) in dist.opt_qem_params.items():
                    if paramname in opt_paramname2tensor:
                        raise Exception(f"OptParam is trying to add parameter named {paramname}, but there's already a parameter with this name")
                    opt_paramname2tensor[paramname] = expand_named(param.init, platenames, all_platesizes)
                    self.opt_paramname2trans[paramname] = param.trans
            else:
                #A QEM distribution, so does not contain OptParams.
                self.qem_list_varname.append(varname)
                conversion = conversion_dict[dist.dist]
                self.qem_list_conversion.append(conversion)

                #Sufficient statistics for a distribution, specified in the form of
                #rmkeys: tuple[tuple[varname], RawMoment]
                #so we can pass it directly to e.g. marginals.moments.
                rmkeys = [((varname,), mom) for mom in conversion.sufficient_stats]
                self.qem_flat_list_rmkeys.extend(rmkeys)
                self.qem_list_rmkeys.append(rmkeys)

                #Expand the initial conventional parameters provided to the distribution
                #and use them to compute the initial mean parameters, using conversion.
                init_conv_dict = {}
                for paramname, (distargname, param) in dist.opt_qem_params.items():
                    if paramname in qem_params:
                        raise Exception(f"QEMParam is trying to add parameter named {paramname}, but there's already a parameter with this name")
                    expanded_conv_param = expand_named(param.init, platenames, all_platesizes)
                    qem_params[paramname] = expanded_conv_param
                    init_conv_dict[distargname] = expanded_conv_param
                    #conversion.mean2conv produces a dict distargname -> Tensor, so record
                    #how (varname, distargname) maps back to paramname.
                    self.qem_varname_distargname2paramname[(varname, distargname)] = paramname
                init_means = conversion.conv2mean(**init_conv_dict)

                for rmkey in rmkeys:
                    _, rawmoment = rmkey
                    meanname = f"{varname}_{moments_func2name[rawmoment]}"
                    self.qem_rmkey2meanname[rmkey] = meanname
                    self.qem_meanname2rmkey[meanname] = rmkey

                #Put these initial mean parameters into the moving-average moments dict.
                for rmkey, init_mean in zip(rmkeys, init_means):
                    qem_meanname2mom[self.qem_rmkey2meanname[rmkey]] = init_mean

        ################################################
        #### Finished setting up Opt/QEM Params!    ####
        #### Now, assign stuff to Param/BufferStore ####
        #### so it is properly registered on device ####
        ################################################

        self._inputs = BufferStore(inputs)
        self._opt_params = ParameterStore(opt_paramname2tensor)
        self._qem_params = BufferStore(qem_params) #dict mapping paramname -> conventional parameter.
        self._qem_means = BufferStore(qem_meanname2mom) #dict mapping meanname -> moving average mean parameter.
        self._dists = ModuleStore(plate.varname2dist())

        ###################################
        #### A bit more error checking ####
        ###################################

        #Gather names from each source separately: merging them into a single dict first
        #(as inputs_params_flat_named does) would silently collapse duplicate names.
        input_param_names = [*self.inputs(), *self.opt_params(), *self.qem_params()]
        set_input_param_names = set(input_param_names)

        #Error checking: input, param names aren't reserved
        for name in input_param_names:
            check_name(name)

        #Error checking: no overlap between names in inputs and params.
        if len(input_param_names) != len(set_input_param_names):
            raise Exception(f"BoundPlate has overlapping names in inputs, opt_params, and/or qem_params")

        #Error checking: no overlap between names in program and in inputs or params.
        prog_names = self.plate.all_prog_names()
        prog_input_param_names_overlap = set_input_param_names.intersection(prog_names)
        if 0 != len(prog_input_param_names_overlap):
            raise Exception(f"The program in BoundPlate has plate/random variable names that overlap with the inputs/params. Specifically {prog_input_param_names_overlap}.")

        #Check that all the dependencies make sense by sampling.
        self.sample()

    @property
    def device(self):
        """
        The device of the internal `_device_tensor` buffer, which moves along with the module (e.g. under `.to(device)`).
        """
        return self._device_tensor.device

    def inputs(self):
        """
        Returns a dictionary of the inputs.
        """
        return self._inputs.to_dict()

    def qem_params(self):
        """
        Returns a dictionary of the parameters learned using QEM.
        """
        return self._qem_params.to_dict()

    def qem_means(self):
        """
        Returns a dictionary of the exponential moving average moments used for QEM.
        """
        return self._qem_means.to_dict()

    def opt_params(self):
        """
        Returns a dictionary of the parameters learned by optimization, with each parameter's transform applied.
        """
        return {
            paramname: self.opt_paramname2trans[paramname](tensor)
            for paramname, tensor in self._opt_params.to_dict().items()
        }

    def _update_qem_convparams(self):
        """
        Converts the moving-average moments in self._qem_means into conventional parameters,
        and copies them in place into the corresponding entries of self._qem_params.
        """
        meanname2mom = self.qem_means()
        for varname, conversion, rmkeys in zip(self.qem_list_varname, self.qem_list_conversion, self.qem_list_rmkeys):
            means = [meanname2mom[self.qem_rmkey2meanname[rmkey]] for rmkey in rmkeys]
            conv_dict = conversion.mean2conv(*means)
            for distargname, tensor in conv_dict.items():
                paramname = self.qem_varname_distargname2paramname[varname, distargname]
                getattr(self._qem_params, paramname).copy_(tensor)

    def _update_qem_moving_avg(self, lr, sample, computation_strategy):
        """
        Updates each moving-average moment in place: mom <- (1-lr)*mom + lr*new_moment,
        where the new moments are computed from `sample`.
        """
        rmkey_list = self.qem_flat_list_rmkeys
        if 0 < len(rmkey_list):
            new_moment_list = sample.moments(rmkey_list, computation_strategy=computation_strategy)
            for rmkey, new_moment in zip(rmkey_list, new_moment_list):
                meanname = self.qem_rmkey2meanname[rmkey]
                tensor = getattr(self._qem_means, meanname)
                assert set(non_none_names(tensor)) == set(non_none_names(new_moment))
                new_moment = new_moment.align_as(tensor)
                tensor.mul_(1-lr).add_(new_moment, alpha=lr)

    def _update_qem_params(self, lr, sample, computation_strategy):
        """
        One QEM step: update the moving-average moments, then recompute the conventional parameters.
        """
        self._update_qem_moving_avg(lr, sample, computation_strategy)
        self._update_qem_convparams()

    def inputs_params_flat_named(self):
        """
        Returns a dict mapping from str -> named tensor, containing the inputs, opt_params and qem_params.
        """
        return {**self.inputs(), **self.opt_params(), **self.qem_params()}

    def inputs_params(self, all_platedims:dict[str, Dim]):
        """
        Returns the inputs and params as a tree of torchdim tensors, structured according to self.plate.
        """
        return named2torchdim_flat2tree(self.inputs_params_flat_named(), all_platedims, self.plate)

    def sample_extended(
            self,
            sample:dict,
            name:Optional[str],
            scope:dict[str, Tensor],
            inputs_params:dict,
            original_platedims:dict[str, Dim],
            extended_platedims:dict[str, Dim],
            active_original_platedims:list[Dim],
            active_extended_platedims:list[Dim],
            Ndim:Dim,
            reparam:bool,
            original_data:Optional[dict[str, Tensor]],
            extended_data:Optional[dict[str, Tensor]]):
        """
        Adds the flat named inputs/params to `scope`, then delegates to self.plate.sample_extended.
        """
        scope = {**scope, **self.inputs_params_flat_named()}
        return self.plate.sample_extended(
            sample,
            name,
            scope,
            inputs_params,
            original_platedims,
            extended_platedims,
            active_original_platedims,
            active_extended_platedims,
            Ndim,
            reparam,
            original_data,
            extended_data)

    def _sample(self, K: int, reparam:bool, sampler:Sampler, all_platedims:dict[str, Dim]):
        """
        Internal sampling method.

        Returns:
            sample: the result of self.plate.sample, with a different K-dimension for each variable group.
            groupvarname2Kdim: dict mapping groupvarname to the K Dim used for that group.
        """
        assert isinstance(K, int)
        assert isinstance(reparam, bool)
        assert issubclass(sampler, Sampler)

        groupvarname2Kdim = self.plate.groupvarname2Kdim(K)

        sample = self.plate.sample(
            name=None,
            scope={},
            inputs_params=self.inputs_params(all_platedims),
            active_platedims=[],
            all_platedims=all_platedims,
            groupvarname2Kdim=groupvarname2Kdim,
            sampler=sampler,
            reparam=reparam,
        )

        return sample, groupvarname2Kdim

    def sample(self):
        """
        Returns a single sample from the model, as a flat dictionary of named tensors, where the names correspond to plate dimensions.
        """
        all_platedims = {platename: Dim(platename, size) for (platename, size) in self.all_platesizes.items()}
        platedims = list(all_platedims.values())

        torchdim_tree_withK, _ = self._sample(1, False, PermutationSampler, all_platedims)
        torchdim_flatdict_withK = flatten_tree(torchdim_tree_withK)

        torchdim_flatdict_noK = {}
        for k, v in torchdim_flatdict_withK.items():
            #Any torchdim that isn't a plate dim is a K dim; with K=1 these can be squeezed away.
            K_dims = list(set(generic_dims(v)).difference(platedims))
            v = v.order(K_dims)
            v = v.squeeze(tuple(range(len(K_dims))))
            torchdim_flatdict_noK[k] = v.detach()

        return dim2named_dict(torchdim_flatdict_noK)

    def groupvarname2platenames(self):
        return self.plate.groupvarname2platenames()

    def varname2groupvarname(self):
        return self.plate.varname2groupvarname()
def check_timeseries(top_plate:Plate):
    """
    Entry point for validating Timeseries placement in the whole program.

    Walks the top-level entries of `top_plate.grouped_prog`. Each nested Plate is checked
    recursively by `check_timeseries_inner`, with `top_plate` as its parent. Every other
    entry is expected to be a dict; its contents are not inspected here.
    """
    assert isinstance(top_plate, Plate)
    for entry in top_plate.grouped_prog.values():
        if not isinstance(entry, Plate):
            assert isinstance(entry, dict)
            continue
        check_timeseries_inner(entry, top_plate)
def check_timeseries_inner(current_plate:Plate, upper_plate:Plate):
    """
    Recursively checks the Timeseries in `current_plate`, and in all plates nested inside it.

    For each dict entry (group) of `current_plate.grouped_prog`:
      - every Timeseries must have an initializer (`Timeseries.init`) that is present in
        `upper_plate.flat_prog`, i.e. in the immediately enclosing plate;
      - all the initializers for the Timeseries in that group must themselves belong to
        the same group in `upper_plate`.
    Nested Plates are checked recursively, with `current_plate` as their parent.

    Raises:
        Exception: if an initializer is missing from `upper_plate`, or if the initializers
        for one group of Timeseries are spread across different groups.
    """
    assert isinstance(current_plate, Plate)
    assert isinstance(upper_plate, Plate)
    upper_varname2groupvarname = upper_plate.varname2groupvarname()

    for k, v in current_plate.grouped_prog.items():
        if isinstance(v, dict):
            #Gather the (upper-plate) groups of the initializers of the timeseries in this group.
            init_groupnames = []
            for gk, gv in v.items():
                if isinstance(gv, Timeseries):
                    init_varname = gv.init
                    if init_varname not in upper_plate.flat_prog:
                        raise Exception(f"Timeseries must have an initializer that is present in the immediate parent plate. However, the initializer for timeseries {gk}, i.e. {init_varname} doesn't seem to be present in the immediate parent plate.")
                    init_groupnames.append(upper_varname2groupvarname[init_varname])

            #Check all init_groupnames are the same
            if 1 < len(set(init_groupnames)):
                raise Exception(f"The initializers for a plate must be grouped in the same way as the timeseries themselves. However, the initializers for timeseries {list(v.keys())}, on group {k} seemed to be grouped differently")
        else:
            assert isinstance(v, Plate)
            check_timeseries_inner(v, current_plate)