Shortcuts

Source code for torch.utils.data.distributed

import math
import torch
from . import Sampler
import torch.distributed as dist


class DistributedSampler(Sampler):
    r"""Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a :class:`~torch.utils.data.DistributedSampler` instance
    as a :class:`~torch.utils.data.DataLoader` sampler, and load a subset of
    the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size. When the dataset length is
        not evenly divisible by :attr:`num_replicas`, indices from the start
        of the (possibly shuffled) ordering are repeated so that every replica
        yields the same number of samples.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, :attr:`world_size` is retrieved
            from the current distributed group.
        rank (int, optional): Rank of the current process within
            :attr:`num_replicas`. By default, :attr:`rank` is retrieved from
            the current distributed group.
        shuffle (bool, optional): If ``True`` (default), sampler will shuffle
            the indices.
        seed (int, optional): random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.

    Raises:
        RuntimeError: if :attr:`num_replicas` or :attr:`rank` is omitted and
            the distributed package is not available.
        ValueError: if :attr:`rank` is outside ``[0, num_replicas - 1]``.

    .. warning::
        In distributed mode, calling the :meth:`set_epoch` method at
        the beginning of each epoch **before** creating the
        :class:`DataLoader` iterator is necessary to make shuffling work
        properly across multiple epochs. Otherwise, the same ordering will be
        always used.

    Example::

        >>> sampler = DistributedSampler(dataset) if is_distributed else None
        >>> loader = DataLoader(dataset, shuffle=(sampler is None),
        ...                     sampler=sampler)
        >>> for epoch in range(start_epoch, n_epochs):
        ...     if is_distributed:
        ...         sampler.set_epoch(epoch)
        ...     train(loader)
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        # Fail early: an out-of-range rank would otherwise only surface as an
        # opaque assertion failure inside __iter__.
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Per-replica sample count, rounded up so all replicas get equally
        # many samples; total_size is the padded length shared by all ranks.
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        """Return an iterator over this replica's share of dataset indices."""
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        padding_size = self.total_size - len(indices)
        if padding_size <= len(indices):
            indices += indices[:padding_size]
        else:
            # More padding is needed than there are samples (dataset smaller
            # than the padding, e.g. 1 sample for 3 replicas): cycle through
            # the indices as many times as required.
            repeats = int(math.ceil(padding_size * 1.0 / len(indices)))
            indices += (indices * repeats)[:padding_size]
        assert len(indices) == self.total_size

        # subsample: every num_replicas-th index, starting at this rank
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        """Return the number of samples yielded for this replica."""
        return self.num_samples

    def set_epoch(self, epoch):
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this
        ensures all replicas use a different random ordering for each epoch.
        Otherwise, the next iteration of this sampler will yield the same
        ordering.

        Arguments:
            epoch (int): Epoch number.
        """
        self.epoch = epoch

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources