import torch.nn as nn
from .dist import _Dist, sample_gdt
from .utils import *
from .Sampler import Sampler
#Checks to Timeseries init.
#if 0 == len(active_platedims):
# raise Exception(f"Timeseries can't be in the top-layer plate, as there's no platesize at the top")
#if name not in self.trans.all_args:
# raise Exception(f"The timeseries transition distribution for {name} must have some dependence on the previous timestep; you get that by including {name} as an argument in the transition distribution.")
[docs]
class Timeseries(nn.Module):
"""
In progress!!!!
See `examples/timeseries.py`
Arguments:
init (str):
string, representing the initial state as a random variable. This random variable must have been sampled in the immediately above plate.
trans (Dist):
transition distribution.
As an example:
.. code-block:: python
Plate(
ts_init = Normal(0., 1.),
T = Plate(
ts = Timeseries('ts_init', Normal(lambda ts: 0.9*ts, 0.1)),
)
)
In the exmplae:
* ``T`` is the plate (i.e. ``all_platesizes['T']``) is the length of the timeseries. Note that this is a slight abuse of the term "Plate", which is usually only used to refer to independent variables.
* ``ts`` is the name of the timeseries random variable itself.
* ``Normal(lambda ts: 0.9*ts, 0.1)`` is the transition distribution. Note that it refers to the previous step of itself using the timeseries name itself, ``ts``, as an argument.
* ``ts_init`` is the initial state. Must be a string representing a random variable in the previous plate.
Non-split implementation notes:
* Non-split log_PQ_plate returns a K_ts_init tensor.
* Splitting log_PQ_plate:
- Uses a backward pass, so at the start of the backward pass, we sum from the back.
- At the start of the backward pass, log_PQ_plate takes one unusual input: initial timeseries state, with dimension K_ts. If initial timeseries state is provided as a kwarg, we ignore Timeseries.init.
- log_PQ_plate returns K_ts dimensional tensor, resulting from summing all the way from the back to the start of the split.
- The next split takes two unusual arguments: the initial state, and the log_pq from the last split evaluated by the backward pass.
Note:
You can't currently split along a timeseries dimension (and you may never be able to).
Note:
OptParam and QEMParam are currently banned in timeseries.
"""
def __init__(self, init, trans):
super().__init__()
self.qem_dist = False
self.is_timeseries = True
if not isinstance(init, str):
raise Exception(f"the first / `init` argument in a Timeseries should be a string, representing a variable name in the above plate")
if not isinstance(trans, _Dist):
raise Exception("the second / `trans` argument in a Timeseries should be a distribution")
if trans.sample_shape != t.Size([]):
raise Exception("sample_shape on the transition distribution must not be set; if you want a sample_shape, it needs to be on the initial state")
self.init = init
self.trans = trans.finalize(None)
assert not self.trans.qem_dist
#Will include own name, but that'll be eliminated in the first step of sample_gdt
self.all_args = [init, *self.trans.all_args]
@property
def opt_qem_params(self):
return self.trans.opt_qem_params
def sample(self, scope, reparam: bool, active_platedims:list[Dim], K_dim:Dim, timeseries_perm):
assert 0 <= len(active_platedims)
(other_platedims, T_dim) = (active_platedims[:-1], active_platedims[-1])
#Set previous state equal to initial state.
prev_state = scope[self.init]
#Check that prev_state has the right dimensions
if set(prev_state.dims) != set([K_dim, *other_platedims]):
raise Exception(f"Initial state, {self.init}, doesn't have the right dimensions for timeseries {name}; the initial state must be defined one step up in the plate heirarchy")
sample_timesteps = []
for time in range(T_dim.size):
#new scope, where we select out the time'th timestep for any tensor with a time dimension.
#all these variables have already been resampled.
timeseries_scope = {}
for k, v in scope.items():
if T_dim in set(generic_dims(v)):
v = v.order(T_dim)[time]
timeseries_scope[k] = v
#Put previous timestep for self into scope
timeseries_scope['prev'] = prev_state
#sample the next timestep
sample_timestep = self.trans.sample(timeseries_scope, reparam, other_platedims, K_dim, None)
sample_timesteps.append(sample_timestep)
#Permute this timestep, ready for being used as prev_state.
if timeseries_perm is not None:
timestep_perm = timeseries_perm.order(T_dim)[time]
sample_timestep = sample_timestep.order(K_dim)[timestep_perm, ...][K_dim]
prev_state = sample_timestep
return t.stack(sample_timesteps, 0)[T_dim]
def log_prob(self, sample, scope:dict, T_dim:Dim, Kinit_dim:Dim, K_dim:Dim):
assert isinstance(scope, dict)
assert isinstance(sample, Tensor)
assert isinstance(T_dim, Dim)
assert isinstance(Kinit_dim, Dim)
assert isinstance(K_dim, Dim)
initial_state = scope[self.init]
sample_prev = sample.order(K_dim)[Kinit_dim]
sample_prev = t.cat([
initial_state[None, ...],
sample.order(T_dim)[:-1],
], 0)[T_dim]
scope = {**scope}
scope['prev'] = sample_prev
return self.trans.log_prob(sample, scope, None, None, None)