# Module: alan.Group

import torch as t

from typing import Optional
from .dist import _Dist, sample_gdt
from .utils import *
from .Sampler import Sampler


class Group():
    """
    A class used when defining the model in order to speed up inference.

    Alan fundamentally works by drawing K samples of each latent variable,
    and considering all possible combinations of those variables. That sounds
    as though it would be impossible, as there are K^n combinations, where n
    is the number of latent variables. Alan circumvents this difficulty by
    using message-passing-like algorithms that exploit conditional
    independencies, giving computation that is polynomial in K. However, the
    complexity can still be K^3 or K^4, and grouping variables helps to reduce
    that power.

    In particular, consider two models:

    Slower model (K^3):

    .. code-block:: python

        Plate(
            loc = Normal(0., 1.),
            log_scale = Normal(0., 1.),
            d = Normal(loc, lambda log_scale: log_scale.exp()),
            ...
        )

    This model has K^3 complexity, as we need to compute the log-probability
    of all K samples of ``d`` for all K^2 combinations of samples of ``loc``
    and ``log_scale``. (There are K samples of ``loc`` and K samples of
    ``log_scale``, so K^2 combinations of samples of ``loc`` and
    ``log_scale``.) That K^3 complexity is excessive for this simple model.
    One solution is to not consider all K^2 combinations of ``loc`` and
    ``log_scale``, but instead consider only the K corresponding samples.
    That is precisely what ``Group`` does:

    Faster model (K^2):

    .. code-block:: python

        Plate(
            g = Group(
                loc = Normal(0., 1.),
                log_scale = Normal(0., 1.),
            ),
            d = Normal(loc, lambda log_scale: log_scale.exp()),
            ...
        )

    The arguments to ``Group`` are very similar to those of :class:`.Plate`,
    except that a ``Group`` can only contain distributions, not sub-plates,
    sub-groups or :class:`.Data`.
    """
    def __init__(self, **kwargs):
        # Groups can only contain Dists, not Plates/Timeseries/Data/other Groups.
        for varname, dist in kwargs.items():
            if not isinstance(dist, _Dist):
                raise Exception(f"{varname} in a Group should be a Dist, but is actually {type(dist)}")
        # A Group with a single variable would give no speed-up, so reject it.
        if len(kwargs) < 2:
            raise Exception(f"Groups only make sense if they have two or more random variables, but this group only has {len(kwargs)} random variables")
        self.prog = {varname: dist.finalize(varname) for (varname, dist) in kwargs.items()}
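The K^3 versus K^2 argument in the ``Group`` docstring can be made concrete by counting how many log-probabilities of ``d`` get evaluated. The sketch below is a plain-Python illustration under stated assumptions, not Alan's implementation or API: ``normal_logpdf``, ``ungrouped_logps``, ``grouped_logps`` and ``count_terms`` are hypothetical helpers, and the K samples of ``d`` are drawn independently purely so there is something to count.

```python
import math
import random

# Hypothetical illustration only, not Alan's API: plain Python lists mimic
# which log-probabilities of d must be evaluated when loc and log_scale are
# combined freely (ungrouped) versus paired index-by-index (as in a Group).


def normal_logpdf(x: float, loc: float, scale: float) -> float:
    """Log-density of Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def ungrouped_logps(locs, log_scales, ds):
    """K^3 terms: every sample of d against every (loc, log_scale) combination."""
    return [
        [[normal_logpdf(d, loc, math.exp(ls)) for d in ds] for ls in log_scales]
        for loc in locs
    ]


def grouped_logps(locs, log_scales, ds):
    """K^2 terms: only the K corresponding pairs (locs[i], log_scales[i])
    are considered, each against every sample of d."""
    return [
        [normal_logpdf(d, loc, math.exp(ls)) for d in ds]
        for loc, ls in zip(locs, log_scales)
    ]


def count_terms(nested) -> int:
    """Count the scalar log-probability terms in a nested list."""
    if isinstance(nested, list):
        return sum(count_terms(x) for x in nested)
    return 1


if __name__ == "__main__":
    K = 5
    rng = random.Random(0)
    locs = [rng.gauss(0.0, 1.0) for _ in range(K)]
    log_scales = [rng.gauss(0.0, 1.0) for _ in range(K)]
    ds = [rng.gauss(0.0, 1.0) for _ in range(K)]
    print(count_terms(ungrouped_logps(locs, log_scales, ds)))  # 125 == K**3
    print(count_terms(grouped_logps(locs, log_scales, ds)))    # 25 == K**2
```

The only difference between the two helpers is ``zip``: pairing ``loc`` and ``log_scale`` by sample index removes one factor of K, which is the saving the docstring attributes to ``Group``.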