Group

class alan.Group(**kwargs)

A class used when defining the model in order to speed up inference.

Alan fundamentally works by drawing K samples of each latent variable and considering all possible combinations of those samples. This sounds infeasible, as there are K^n combinations, where n is the number of latent variables. Alan circumvents this difficulty by using message-passing-like algorithms that exploit conditional independencies, giving computation that is polynomial in K. However, the complexity can still be K^3 or K^4, and grouping variables helps to reduce that power.
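As a rough illustration of why exploiting conditional independencies helps, the NumPy sketch below sums over all combinations of K samples for a chain of n latent variables, first by brute force (K^n terms) and then by passing a length-K message along the chain (roughly n K^2 operations). The chain structure, the random pairwise `factors`, and both function names are assumptions made for this example; it is not Alan's implementation.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
K, n = 4, 5  # K samples per latent variable, n latent variables in a chain

# Hypothetical pairwise factors: factors[i][a, b] couples sample a of
# latent i with sample b of latent i + 1.
factors = [rng.uniform(0.5, 1.5, size=(K, K)) for _ in range(n - 1)]


def brute_force_sum(factors, K):
    """Sum the product of factors over all K**n combinations of samples."""
    n = len(factors) + 1
    total = 0.0
    for combo in itertools.product(range(K), repeat=n):
        term = 1.0
        for i, f in enumerate(factors):
            term *= f[combo[i], combo[i + 1]]
        total += term
    return total


def message_passing_sum(factors, K):
    """Compute the same sum by passing a length-K message along the chain."""
    message = np.ones(K)
    for f in reversed(factors):
        # Sum out the later variable: a (K, K) @ (K,) product costs K**2.
        message = f @ message
    return message.sum()


brute = brute_force_sum(factors, K)
fast = message_passing_sum(factors, K)
print(K ** n, "terms by brute force; results agree:", np.isclose(brute, fast))
```

Both functions return the same value, but only the brute-force version grows exponentially in n.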

In particular, consider the following two models.

Slower model (K^3):

Plate(
    loc       = Normal(0., 1.),
    log_scale = Normal(0., 1.),
    d = Normal(loc, lambda log_scale: log_scale.exp()),
    ...
)

This model has K^3 complexity, as we need to compute the log-probability of all K samples of d for all K^2 combinations of samples of loc and log_scale (there are K samples of loc and K samples of log_scale, so K^2 combinations). That K^3 complexity is excessive for such a simple model. One solution is to not consider all K^2 combinations of loc and log_scale, but instead to consider only the K corresponding pairs of samples. That’s precisely what Group does:

Faster model (K^2):

Plate(
    g = Group(
        loc       = Normal(0., 1.),
        log_scale = Normal(0., 1.),
    ),
    d = Normal(loc, lambda log_scale: log_scale.exp()),
    ...
)
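The difference between the two models can be seen in the shape of the log-probability array for d. The NumPy sketch below is only an illustration of the counting argument, not how Alan computes things internally; `normal_log_prob` and the randomly drawn stand-in samples are assumptions made for this example.

```python
import numpy as np


def normal_log_prob(x, loc, scale):
    """Log-density of a Normal(loc, scale) distribution evaluated at x."""
    return -0.5 * np.log(2 * np.pi) - np.log(scale) - 0.5 * ((x - loc) / scale) ** 2


rng = np.random.default_rng(0)
K = 5

# Stand-ins for the K samples of each latent variable.
loc = rng.normal(0.0, 1.0, size=K)
log_scale = rng.normal(0.0, 1.0, size=K)
d = rng.normal(0.0, 1.0, size=K)
scale = np.exp(log_scale)

# Ungrouped: every sample of d against every (loc, log_scale) combination.
ungrouped = normal_log_prob(
    d[:, None, None], loc[None, :, None], scale[None, None, :]
)

# Grouped: sample k of loc is only ever paired with sample k of log_scale.
grouped = normal_log_prob(d[:, None], loc[None, :], scale[None, :])

print(ungrouped.shape)  # K**3 entries
print(grouped.shape)    # K**2 entries
```

The grouped array is exactly the slice of the ungrouped array in which the loc and log_scale sample indices coincide, which is why grouping drops one power of K.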

The arguments to Group are very similar to those of Plate, except that you can only have distributions, not sub-plates, sub-groups or Data.