import torch as t
from typing import Optional
from functorch.dim import Dim
from .utils import *
from .Sampler import Sampler
from .dist import Dist
from .Group import Group
from .Data import Data
from .dist import Dist, _Dist, sample_gdt, datagroup
from .Timeseries import Timeseries
[docs]
class Plate():
    """
    The key class used to define your model: all random variables are defined within a ``Plate``.

    An example plate definition:

    .. code-block:: python

        Plate(
            a = Normal(0., 1.),
            g = Group(
                b = Normal('a', 1.),
                c = Normal('b', 1.),
            ),
            p = Plate(
                d = Data(),
            ),
        )

    Everything in the plate is specified as a keyword argument (i.e. of the form ``name = thing``), where ``thing`` could be:

    * a distribution (see :ref:`Distributions`).
    * a :class:`.Timeseries`.
    * a :class:`.Group`.
    * a sub-plate Plate.
    * :class:`Data() <.Data>`.

    Critically, the name becomes the name of that thing. So in the above example, we have normal random variables named ``a``, ``b``, and ``c``, a group named ``g``, a sub-plate named ``p``, and a variable ``d`` that will be associated with data (see :class:`.Data`).

    Note:
        In standard Bayesian terminology, including a variable within a plate indicates that there are actually several copies of that variable. That's precisely how we're using sub-plates: each sub-plate (``p`` in the example) will have an assigned plate size, and we'll replicate each variable within the plate that number of times.

        Note that we also use ``Plate`` at the top layer, even though we only have one copy of the top-layer variables.
    """

    def __init__(self, **kwargs):
        # Finalise any dists: ``_Dist`` objects are converted via ``finalize``,
        # which is given the variable's name.
        kwargs = {k: v.finalize(k) if isinstance(v, _Dist) else v for (k, v) in kwargs.items()}

        # grouped_prog: name -> sub-Plate, or name -> dict of the variables in that group
        #     (a bare Dist/Timeseries/Data becomes a one-entry dict keyed by its own name).
        # flat_prog: name -> sub-Plate or individual variable, with groups flattened out.
        self.grouped_prog = {}
        self.flat_prog = {}
        for k, v in kwargs.items():
            if isinstance(v, Plate):
                self.grouped_prog[k] = v
                self.flat_prog[k] = v
            else:
                assert isinstance(v, (Group, Dist, Timeseries, Data)), \
                    f"Plate argument {k!r} has unsupported type {type(v).__name__}"
                group = v.prog if isinstance(v, Group) else {k: v}
                self.grouped_prog[k] = {}
                for gk, gv in group.items():
                    self.grouped_prog[k][gk] = gv
                    self.flat_prog[gk] = gv

        # Error checking: plate/variable/group names aren't reserved.
        all_prog_names = self.all_prog_names()
        for name in all_prog_names:
            check_name(name)

        # Error checking: no duplicate names anywhere in the (nested) program.
        dup_names = list_duplicates(all_prog_names)
        if 0 != len(dup_names):
            raise Exception(f"Plate has duplicate names {dup_names}.")

    def grouped_get(self, d, groupname):
        """
        Look up the entries of ``d`` belonging to the group/sub-plate ``groupname``.

        For a group, returns a dict mapping each of the group's variable names to
        ``d.get(varname)`` (``None`` when missing). For a sub-plate, returns ``d[groupname]``.
        """
        gv = self.grouped_prog[groupname]
        if isinstance(gv, dict):
            return {k: d.get(k) for k in gv}
        else:
            assert isinstance(gv, Plate)
            return d[groupname]

    def sample(
        self,
        name: Optional[str],
        scope: dict[str, Tensor],
        inputs_params: dict,
        active_platedims: list[Dim],
        all_platedims: dict[str, Dim],
        groupvarname2Kdim: dict[str, Dim],
        sampler: Sampler,
        reparam: bool,
    ):
        """
        Sample every variable in this plate (recursing into sub-plates).

        Args:
            name: name of this plate, or ``None`` for the top-layer plate. When not
                ``None``, ``all_platedims[name]`` is appended to ``active_platedims``.
            scope: values visible to the variables in this plate.
            inputs_params: inputs/params tree for this plate; its non-dict entries are
                added to the scope, and ``inputs_params.get(childname)`` is passed to sub-plates.
            active_platedims: plate dims already active from enclosing plates.
            all_platedims: plate name -> plate ``Dim``.
            groupvarname2Kdim: group/variable name -> K ``Dim`` (see ``groupvarname2Kdim``).
            sampler, reparam: forwarded unchanged to ``sample_gdt`` and sub-plates.

        Returns:
            A nested dict: variable name -> sample, sub-plate name -> that sub-plate's dict.
        """
        if name is not None:
            active_platedims = [*active_platedims, all_platedims[name]]
        # update_scope returns a copy, so the writes to ``scope`` below stay local.
        scope = update_scope(scope, inputs_params)

        sample = {}
        for childname, prog in self.grouped_prog.items():
            if isinstance(prog, dict):
                # Groups for which ``datagroup`` holds are not sampled here
                # (presumably groups of Data — confirm in .dist).
                if datagroup(prog):
                    continue
                childsample = sample_gdt(
                    prog=prog,
                    scope=scope,
                    active_platedims=active_platedims,
                    K_dim=groupvarname2Kdim[childname],
                    groupvarname2Kdim=groupvarname2Kdim,
                    sampler=sampler,
                    reparam=reparam
                )
                # Later groups may depend on these samples, so also add them to the scope.
                for k, v in childsample.items():
                    sample[k] = v
                    scope[k] = v
            else:
                assert isinstance(prog, Plate)
                platesample = prog.sample(
                    name=childname,
                    scope=scope,
                    inputs_params=inputs_params.get(childname),
                    active_platedims=active_platedims,
                    all_platedims=all_platedims,
                    groupvarname2Kdim=groupvarname2Kdim,
                    sampler=sampler,
                    reparam=reparam,
                )
                sample[childname] = platesample
                scope[childname] = platesample
        return sample

    @staticmethod
    def _add_child_to_scope(scope, childname, childsample):
        """
        Return ``scope`` extended with one child's sample.

        A dict (a sub-plate's sample tree) is merged with ``update_scope``. Any other
        value is added under ``childname`` (also via ``update_scope``, which checks
        for clashes). ``None`` (nothing sampled for that child) leaves ``scope`` unchanged.
        """
        if childsample is None:
            return scope
        if isinstance(childsample, dict):
            return update_scope(scope, childsample)
        return update_scope(scope, {childname: childsample})

    def sample_extended(
            self,
            sample: dict,
            name: Optional[str],
            scope: dict[str, Tensor],
            inputs_params: dict,
            original_platedims: dict[str, Dim],
            extended_platedims: dict[str, Dim],
            active_extended_platedims: list[Dim],
            Ndim: Dim,
            reparam: bool,
            original_data: Optional[dict[str, Tensor]]):
        """
        Call ``sample_extended`` on every child of this plate, writing results into ``sample``.

        ``sample`` is updated in place (``sample[childname] = childsample``) and also returned.
        When ``name`` is not ``None``, ``extended_platedims[name]`` is appended to
        ``active_extended_platedims``, and children receive ``original_data[name]``.
        """
        if name is not None:
            active_extended_platedims = [*active_extended_platedims, extended_platedims[name]]
        scope = update_scope(scope, inputs_params)

        # Only pass the data for the current plate down to the children.
        # This does not depend on the child, so compute it once.
        if name is not None and original_data is not None:
            child_original_data = original_data[name]
        else:
            child_original_data = original_data

        # flat_prog holds each individual variable / sub-plate.
        # (The class defines no ``prog`` attribute.)
        for childname, childP in self.flat_prog.items():
            childsample = childP.sample_extended(
                sample=sample.get(childname),
                name=childname,
                scope=scope,
                inputs_params=inputs_params.get(childname),
                original_platedims=original_platedims,
                extended_platedims=extended_platedims,
                active_extended_platedims=active_extended_platedims,
                Ndim=Ndim,
                reparam=reparam,
                original_data=child_original_data,
            )
            sample[childname] = childsample
            scope = self._add_child_to_scope(scope, childname, childsample)
        return sample

    def predictive_ll(
            self,
            sample: dict,
            name: Optional[str],
            scope: dict[str, Tensor],
            inputs_params: dict,
            original_platedims: dict[str, Dim],
            extended_platedims: dict[str, Dim],
            original_data: dict[str, Tensor],
            extended_data: dict[str, Tensor]):
        """
        Collect the predictive log-likelihoods from every child of this plate.

        Returns:
            ``(original_lls, extended_lls)``: the merged dicts returned by each child's
            ``predictive_ll``. A later child's keys overwrite an earlier child's.
        """
        scope = update_scope(scope, inputs_params)
        original_lls, extended_lls = {}, {}
        for childname, childP in self.flat_prog.items():
            child_original_lls, child_extended_lls = childP.predictive_ll(
                sample=sample.get(childname),
                name=childname,
                scope=scope,
                inputs_params=inputs_params.get(childname),
                original_platedims=original_platedims,
                extended_platedims=extended_platedims,
                original_data=original_data,
                extended_data=extended_data
            )
            scope = self._add_child_to_scope(scope, childname, sample.get(childname))
            original_lls.update(child_original_lls)
            extended_lls.update(child_extended_lls)
        return original_lls, extended_lls

    def groupvarname2Kdim(self, K):
        """
        Find all the Groups/Dists in the program (at any plate depth) and create a
        K-dimension of size ``K`` for each one, named ``K_<groupname>``.

        Groups for which ``datagroup`` holds are skipped.
        """
        result = {}
        for groupname, v in self.grouped_prog.items():
            if isinstance(v, dict):
                if not datagroup(v):
                    result[groupname] = Dim(f"K_{groupname}", K)
            else:
                assert isinstance(v, Plate)
                result.update(v.groupvarname2Kdim(K))
        return result

    def all_prog_names(self):
        """
        Return all plate/group/variable names in the whole program.
        Used to check that all names in the program are unique.
        """
        result = []
        for k, v in self.grouped_prog.items():
            result.append(k)
            if isinstance(v, dict):
                # A bare variable is stored as the one-entry group ``{k: v}``, so its
                # name is already in ``result``. Any other group's variable names
                # (including a one-variable Group named differently) must be added.
                if 2 <= len(v) or k not in v:
                    result.extend(v.keys())
            else:
                assert isinstance(v, Plate)
                result.extend(v.all_prog_names())
        return result

    def varname2groupvarname_dist(self):
        """
        Map each sampled variable name (at any depth) to ``(group name, Dist/Timeseries)``.
        Groups for which ``datagroup`` holds are skipped.
        """
        result = {}
        for k, v in self.grouped_prog.items():
            if isinstance(v, dict):
                if not datagroup(v):
                    for gk, gv in v.items():
                        assert isinstance(gv, (Dist, Timeseries))
                        result[gk] = (k, gv)
            else:
                assert isinstance(v, Plate)
                result.update(v.varname2groupvarname_dist())
        return result

    def varname2groupvarname(self):
        """Map each sampled variable name to the name of its group."""
        return {varname: groupvarname for (varname, (groupvarname, _)) in self.varname2groupvarname_dist().items()}

    def varname2dist(self):
        """Map each sampled variable name to its Dist/Timeseries."""
        return {varname: dist for (varname, (_, dist)) in self.varname2groupvarname_dist().items()}

    def groupvarname2platenames(self):
        """Map each group name (at any depth) to the list of plate names enclosing it."""
        return self._groupvarname2platenames([])

    def _groupvarname2platenames(self, active_platenames: list[str]):
        """
        Return a dict mapping groupvarname (corresponding to K's) to the names of the
        active plates for that groupvar. Unlike ``groupvarname2Kdim``, this includes
        every group (no ``datagroup`` filtering).
        Used for constructing Js for marginals + posterior sampling.
        """
        result = {}
        for name, dgpt in self.grouped_prog.items():
            if isinstance(dgpt, dict):
                result[name] = active_platenames
            else:
                assert isinstance(dgpt, Plate)
                # Build a fresh list for the sub-plate. Reassigning
                # ``active_platenames`` here would leak this plate's name into
                # the groups and plates that follow it.
                child_platenames = [*active_platenames, name]
                result.update(dgpt._groupvarname2platenames(child_platenames))
        return result

    def all_platenames(self):
        """
        Return the names of all sub-plates nested (at any depth) within this plate,
        in depth-first order. This plate's own name is not included.
        """
        result = []
        for varname, dgpt in self.flat_prog.items():
            if isinstance(dgpt, Plate):
                result = [*result, varname, *dgpt.all_platenames()]
            else:
                assert isinstance(dgpt, (Dist, Data, Timeseries))
        return result
#Functions to update the scope
def update_scope(scope: dict[str, Tensor], samples_inputs_params: dict):
    """
    Return a new scope: ``scope`` plus every non-dict entry of ``samples_inputs_params``.

    Dict-valued entries (nested sub-trees) are not copied into the scope, but their
    names are still checked for clashes. The input ``scope`` is not modified.

    Raises:
        AssertionError: if either argument is not a dict, if a name in
            ``samples_inputs_params`` is already in the scope, or if a non-dict
            value is not a ``Tensor``.
    """
    assert isinstance(scope, dict)
    assert isinstance(samples_inputs_params, dict)
    extended = dict(scope)
    for key, value in samples_inputs_params.items():
        assert key not in extended
        if isinstance(value, dict):
            continue
        assert isinstance(value, Tensor)
        extended[key] = value
    return extended
#### Functions to transform a flat dict to a tree, mirroring the structure of plate.
def empty_tree(plate: Plate):
    """
    Build a nested dict that mirrors the sub-plate structure of ``plate``.

    Each sub-plate name maps to its own (recursively built) empty tree;
    variables are not included.
    """
    assert isinstance(plate, Plate)
    return {
        child_name: empty_tree(child)
        for child_name, child in plate.flat_prog.items()
        if isinstance(child, Plate)
    }
def all_platenames(plate: Plate):
    """
    Extract all platenames from a program.

    Names of sub-plates are returned in depth-first order (each plate name
    followed by the names of the plates nested inside it).
    """
    assert isinstance(plate, Plate)
    names = []
    for child_name, child in plate.flat_prog.items():
        if not isinstance(child, Plate):
            continue
        names.append(child_name)
        names.extend(all_platenames(child))
    return names
def tree_branches(tree: dict):
    """
    Return the sub-trees (dict-valued entries) of ``tree``, keyed by name.

    Every non-dict value must be a ``Tensor``; anything else fails an assertion.
    """
    for value in tree.values():
        assert isinstance(value, (dict, Tensor))
    return {key: value for key, value in tree.items() if isinstance(value, dict)}
def tree_values(tree: dict):
    """
    Return the leaves (``Tensor``-valued entries) at the top level of ``tree``.

    Every non-Tensor value must be a dict (a sub-tree); anything else fails an assertion.
    """
    for value in tree.values():
        assert isinstance(value, (Tensor, dict))
    return {key: value for key, value in tree.items() if isinstance(value, Tensor)}
def tensordict2tree(plate: Plate, tensor_dict: dict[str, Tensor]):
    """
    Arrange a flat dict of named tensors into a nested tree mirroring ``plate``.

    Each tensor is placed in the sub-tree reached by following the plate names
    found among its dims (as reported by ``generic_dims``). At every level exactly
    one of the remaining plate names must be a branch of the current sub-tree,
    otherwise an assertion fails.
    """
    root = empty_tree(plate)
    known_platenames = set(all_platenames(plate))

    for name, tensor in tensor_dict.items():
        remaining = known_platenames.intersection(
            [str(dim) for dim in generic_dims(tensor)]
        )

        # Walk down the tree, one plate level at a time.
        branch = root
        while remaining:
            candidates = remaining.intersection(tree_branches(branch).keys())
            assert 1 == len(candidates)
            (next_plate,) = candidates
            branch = branch[next_plate]
            remaining.remove(next_plate)

        branch[name] = tensor
    return root
def flatten_tree(tree):
    """
    Collapse a nested tree of tensors into a single flat ``name -> Tensor`` dict.

    Sub-trees are flattened recursively in iteration order; if a name occurs more
    than once, the value seen last wins. Non-Tensor values must be dicts.
    """
    flat = {}
    for key, value in tree.items():
        if not isinstance(value, Tensor):
            assert isinstance(value, dict)
            flat.update(flatten_tree(value))
            continue
        flat[key] = value
    return flat
#def treemap(map_func, reduce_func):
# def inner(*trees):
# assert 1 <= len(trees)
#
# if any(isinstance(tree, dict) in trees):
# #If one argument is a dict, they're all dicts
# assert all(isinstance(tree, dict) in trees)
#
# #If they're all dicts, they have the same keys.
# keys0 = set(trees[0].keys())
# assert all(keys0 == set(tree.keys()) for tree in trees[1:])
#
# #If they're dicts, then you can't apply the function yet,
# #so keep recursing.
# result = {}
# for key in keys0:
# result[key] = treemap(f, *[tree[key] for tree in trees])
# return reduce_func(result)
# else:
# #If they aren't dicts finally apply the function.
# return map_func(*trees)
#
#
#def progmap(map_func, reduce_func):
# def inner(name, trees, active_platedims, **consts):
# #Push an extra plate, if not the top-layer plate (top-layer plate is signalled
# #by name=None.
# if name is not None:
# new_platedim = all_platedims[name]
# active_platedims = [*active_platedims, new_platedim]
#
# plate = trees[0]
# assert isinstance(plate, Plate)
#
# result = {}
# for k, v in plate.prog.items():
# if isinstance(v, Plate):
# assert isinstance(tree[k], (Plate, dict))
# result[k] = inner(k, [tree[k] for tree in trees], active_platedims, **consts)
# else:
# result[k] = map_func(k, v, **consts)
#
# return reduce_func(result, **consts)