Shortcuts

Source code for torch.distributions.distribution

import torch
import warnings
from torch.distributions import constraints
from torch.distributions.utils import lazy_property


[docs]class Distribution(object): r""" Distribution is the abstract base class for probability distributions. """ has_rsample = False has_enumerate_support = False _validate_args = False support = None arg_constraints = {} @staticmethod def set_default_validate_args(value): if value not in [True, False]: raise ValueError Distribution._validate_args = value def __init__(self, batch_shape=torch.Size(), event_shape=torch.Size(), validate_args=None): self._batch_shape = batch_shape self._event_shape = event_shape if validate_args is not None: self._validate_args = validate_args if self._validate_args: for param, constraint in self.arg_constraints.items(): if constraints.is_dependent(constraint): continue # skip constraints that cannot be checked if param not in self.__dict__ and isinstance(getattr(type(self), param), lazy_property): continue # skip checking lazily-constructed args if not constraint.check(getattr(self, param)).all(): raise ValueError("The parameter {} has invalid values".format(param)) super(Distribution, self).__init__()
[docs] def expand(self, batch_shape, _instance=None): """ Returns a new distribution instance (or populates an existing instance provided by a derived class) with batch dimensions expanded to `batch_shape`. This method calls :class:`~torch.Tensor.expand` on the distribution's parameters. As such, this does not allocate new memory for the expanded distribution instance. Additionally, this does not repeat any args checking or parameter broadcasting in `__init__.py`, when an instance is first created. Args: batch_shape (torch.Size): the desired expanded size. _instance: new instance provided by subclasses that need to override `.expand`. Returns: New distribution instance with batch dimensions expanded to `batch_size`. """ raise NotImplementedError
@property def batch_shape(self): """ Returns the shape over which parameters are batched. """ return self._batch_shape @property def event_shape(self): """ Returns the shape of a single sample (without batching). """ return self._event_shape @property def arg_constraints(self): """ Returns a dictionary from argument names to :class:`~torch.distributions.constraints.Constraint` objects that should be satisfied by each argument of this distribution. Args that are not tensors need not appear in this dict. """ raise NotImplementedError @property def support(self): """ Returns a :class:`~torch.distributions.constraints.Constraint` object representing this distribution's support. """ raise NotImplementedError @property def mean(self): """ Returns the mean of the distribution. """ raise NotImplementedError @property def variance(self): """ Returns the variance of the distribution. """ raise NotImplementedError @property def stddev(self): """ Returns the standard deviation of the distribution. """ return self.variance.sqrt()
[docs] def sample(self, sample_shape=torch.Size()): """ Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. """ with torch.no_grad(): return self.rsample(sample_shape)
[docs] def rsample(self, sample_shape=torch.Size()): """ Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched. """ raise NotImplementedError
[docs] def sample_n(self, n): """ Generates n samples or n batches of samples if the distribution parameters are batched. """ warnings.warn('sample_n will be deprecated. Use .sample((n,)) instead', UserWarning) return self.sample(torch.Size((n,)))
[docs] def log_prob(self, value): """ Returns the log of the probability density/mass function evaluated at `value`. Args: value (Tensor): """ raise NotImplementedError
[docs] def cdf(self, value): """ Returns the cumulative density/mass function evaluated at `value`. Args: value (Tensor): """ raise NotImplementedError
[docs] def icdf(self, value): """ Returns the inverse cumulative density/mass function evaluated at `value`. Args: value (Tensor): """ raise NotImplementedError
[docs] def enumerate_support(self, expand=True): """ Returns tensor containing all values supported by a discrete distribution. The result will enumerate over dimension 0, so the shape of the result will be `(cardinality,) + batch_shape + event_shape` (where `event_shape = ()` for univariate distributions). Note that this enumerates over all batched tensors in lock-step `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens along dim 0, but with the remaining batch dimensions being singleton dimensions, `[[0], [1], ..`. To iterate over the full Cartesian product use `itertools.product(m.enumerate_support())`. Args: expand (bool): whether to expand the support over the batch dims to match the distribution's `batch_shape`. Returns: Tensor iterating over dimension 0. """ raise NotImplementedError
[docs] def entropy(self): """ Returns entropy of distribution, batched over batch_shape. Returns: Tensor of shape batch_shape. """ raise NotImplementedError
[docs] def perplexity(self): """ Returns perplexity of distribution, batched over batch_shape. Returns: Tensor of shape batch_shape. """ return torch.exp(self.entropy())
def _extended_shape(self, sample_shape=torch.Size()): """ Returns the size of the sample returned by the distribution, given a `sample_shape`. Note, that the batch and event shapes of a distribution instance are fixed at the time of construction. If this is empty, the returned shape is upcast to (1,). Args: sample_shape (torch.Size): the size of the sample to be drawn. """ if not isinstance(sample_shape, torch.Size): sample_shape = torch.Size(sample_shape) return sample_shape + self._batch_shape + self._event_shape def _validate_sample(self, value): """ Argument validation for distribution methods such as `log_prob`, `cdf` and `icdf`. The rightmost dimensions of a value to be scored via these methods must agree with the distribution's batch and event shapes. Args: value (Tensor): the tensor whose log probability is to be computed by the `log_prob` method. Raises ValueError: when the rightmost dimensions of `value` do not match the distribution's batch and event shapes. """ if not isinstance(value, torch.Tensor): raise ValueError('The value argument to log_prob must be a Tensor') event_dim_start = len(value.size()) - len(self._event_shape) if value.size()[event_dim_start:] != self._event_shape: raise ValueError('The right-most size of value must match event_shape: {} vs {}.'. format(value.size(), self._event_shape)) actual_shape = value.size() expected_shape = self._batch_shape + self._event_shape for i, j in zip(reversed(actual_shape), reversed(expected_shape)): if i != 1 and j != 1 and i != j: raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'. format(actual_shape, expected_shape)) if not self.support.check(value).all(): raise ValueError('The value argument must be within the support') def _get_checked_instance(self, cls, _instance=None): if _instance is None and type(self).__init__ != cls.__init__: raise NotImplementedError("Subclass {} of {} that defines a custom __init__ method " "must also define a custom .expand() method.". format(self.__class__.__name__, cls.__name__)) return self.__new__(type(self)) if _instance is None else _instance def __repr__(self): param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__] args_string = ', '.join(['{}: {}'.format(p, self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()) for p in param_names]) return self.__class__.__name__ + '(' + args_string + ')'

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources