Source code for alan.Split

import math
import warnings

from .utils import *

class NoSplit:
    def split_args(self, name, sample, inputs_params, extra_log_factors, data, all_platedims):
        return [{
            'sample':sample, 
            'inputs_params':inputs_params, 
            'extra_log_factors':extra_log_factors, 
            'data':data,
            'all_platedims':all_platedims,
        }]

class NoCheckpoint(NoSplit):
    pass
no_checkpoint = NoCheckpoint()

class Checkpoint(NoSplit):
    pass
checkpoint = Checkpoint()

[docs] class Split: """ A class indicating how to split the computation along a plate. Always used as a ``computation_strategy=Split(...)`` keyword argument. Arguments: platename (str): The name of the plate to split. split_size (str): The size of each split. Note that this is the size of an individual split, not the total number of splits. That's useful, because you can set split_size so that the model fits in memory, and it should still fit in memory if the data gets bigger. """ def __init__(self, platename:str, split_size:int): assert isinstance(platename, str) assert isinstance(split_size, int) self.platename = platename self.split_size = split_size def splitdims(self, all_platedims): return SplitDims(self, all_platedims) def split_args(self, name, sample, inputs_params, extra_log_factors, data, all_platedims): if self.platename == name: split = self.splitdims(all_platedims) samples = split.split_dict(sample) inputs_paramss = split.split_dict(inputs_params) extra_log_factorss = split.split_dict(extra_log_factors) datas = split.split_dict(data) all_platedimss = split.split_all_platedimss else: samples = [sample] inputs_paramss = [inputs_params] extra_log_factorss = [extra_log_factors] datas = [data] all_platedimss = [all_platedims] del sample, inputs_params, extra_log_factors, data, all_platedims result = [] for (s, i, e, d, a) in zip(samples, inputs_paramss, extra_log_factorss, datas, all_platedimss): result.append({ 'sample' : s, 'inputs_params' : i, 'extra_log_factors' : e, 'data' : d, 'all_platedims' : a, }) return result
class SplitDims: def __init__(self, split: Split, all_platedims:dict[str, Dim]): self.split = split self.orig_dim = all_platedims[split.platename] orig_size = self.orig_dim.size split_size = self.split.split_size assert orig_size > split_size self.split_sizes = [*((orig_size//split_size)*[split_size])] remainder = orig_size%split_size assert 0 <= remainder if remainder != 0: self.split_sizes.append(remainder) if (split_size > 2) and self.split_sizes[-1] ==1: #Works around a bug in MPS device, which is triggered if you have a split size of 1 self.split_sizes[-2] = self.split_sizes[-2] - 1 self.split_sizes[-1] = self.split_sizes[-1] + 1 if self.split_sizes[-1] == 1: warnings.warn("Split size of 1, seems to trigger bugs on Apple mps device") self.split_dims = [Dim(f'{self.split.platename}_split_{i}', self.split_sizes[i]) for i in range(len(self.split_sizes))] self.split_all_platedimss = [{**all_platedims, self.split.platename: dim} for dim in self.split_dims] def split_tensor(self, x:Tensor): non_split_dims = [dim for dim in generic_dims(x) if dim is not self.orig_dim] xs = generic_order(x, [self.orig_dim, *non_split_dims]).split(self.split_sizes) return [generic_getitem(x, [split_dim, *non_split_dims]) for (x, split_dim) in zip(xs, self.split_dims)] def split_dict(self, d:dict): """ User-facing method, that takes a tree-structured dict, and returns a split tuple of tree-structured dicts. """ results = [{} for _ in self.split_sizes] self._split_dict_inner(d, results) return results def _split_dict_inner(self, d:dict, results:list[dict]): for (k, v) in d.items(): if isinstance(v, dict): new_results = [] for result in results: result[k] = {} new_results.append(result[k]) self._split_dict_inner(v, new_results) else: assert isinstance(v, Tensor) for result, val in zip(results, self.split_tensor(v)): result[k] = val