Shortcuts

Source code for torch.futures

from typing import Generic, TypeVar

import torch


class _PyFuture(torch._C.Future):
    r"""
    Python-level subclass of ``torch._C.Future``.

    Each method below simply forwards to the base-class implementation; the
    overrides exist to attach Python docstrings to ``wait``, ``then`` and
    ``set_result``.
    """

    def wait(self):
        r"""
        Block until the value of this ``Future`` is ready.

        Returns:
            The value held by this ``Future``. If the function (callback or RPC)
            creating the value has thrown an error, this ``wait`` method will
            also throw an error.
        """
        # Pure delegation: blocking and error propagation are handled by the
        # base class.
        return super().wait()

    def then(self, callback):
        r"""
        Append the given callback function to this ``Future``, which will be run
        when the ``Future`` is completed.  Multiple callbacks can be added to
        the same ``Future``, and will be invoked in the same order as they were
        added. The callback must take one argument, which is the reference to
        this ``Future``. The callback function can use the ``Future.wait()`` API
        to get the value.

        Arguments:
            callback(``Callable``): a ``Callable`` that takes this ``Future`` as
                                    the only argument.

        Returns:
            A new ``Future`` object that holds the return value of the
            ``callback`` and will be marked as completed when the given
            ``callback`` finishes.

        Example::
            >>> import torch
            >>>
            >>> def callback(fut):
            >>>     print(f"RPC return value is {fut.wait()}.")
            >>>
            >>> fut = torch.futures.Future()
            >>> # The inserted callback will print the value once ``fut`` is
            >>> # completed by the ``set_result`` call below.
            >>> cb_fut = fut.then(callback)
            >>> chain_cb_fut = cb_fut.then(
            >>>     lambda x : print(f"Chained cb done. {x.wait()}")
            >>> )
            >>> fut.set_result(5)
            >>>
            >>> # Outputs are:
            >>> # RPC return value is 5.
            >>> # Chained cb done. None
        """
        # Pure delegation: the base class registers the callback and creates
        # the returned ``Future``.
        return super().then(callback)

    def set_result(self, result):
        r"""
        Set the result for this ``Future``, which will mark this ``Future`` as
        completed and trigger all attached callbacks. Note that a ``Future``
        cannot be marked completed twice.

        Arguments:
            result (object): the result object of this ``Future``.

        Returns:
            Nothing; this method is called for its side effect only.

        Example::
            >>> import threading
            >>> import time
            >>> import torch
            >>>
            >>> def slow_set_future(fut, value):
            >>>     time.sleep(0.5)
            >>>     fut.set_result(value)
            >>>
            >>> fut = torch.futures.Future()
            >>> t = threading.Thread(
            >>>     target=slow_set_future,
            >>>     args=(fut, torch.ones(2) * 3)
            >>> )
            >>> t.start()
            >>>
            >>> print(fut.wait())  # tensor([3., 3.])
            >>> t.join()
        """
        # Pure delegation; the base-class return value is intentionally not
        # propagated.
        super().set_result(result)


def collect_all(futures):
    r"""
    Collects the provided :class:`~torch.futures.Future` objects into a single
    combined :class:`~torch.futures.Future` that is completed when all of the
    sub-futures are completed.

    Arguments:
        futures (list): a list of :class:`~torch.futures.Future` objects.

    Returns:
        A :class:`~torch.futures.Future` object whose value is a list of the
        passed in Futures.

    Example::
        >>> import torch
        >>>
        >>> fut0 = torch.futures.Future()
        >>> fut1 = torch.futures.Future()
        >>>
        >>> fut = torch.futures.collect_all([fut0, fut1])
        >>>
        >>> fut0.set_result(0)
        >>> fut1.set_result(1)
        >>>
        >>> fut_list = fut.wait()
        >>> print(f"fut0 result = {fut_list[0].wait()}")
        >>> print(f"fut1 result = {fut_list[1].wait()}")
        >>> # outputs:
        >>> # fut0 result = 0
        >>> # fut1 result = 1
    """
    # The aggregation itself is implemented by the C extension helper.
    return torch._C._collect_all(futures)
def wait_all(futures):
    r"""
    Waits for all provided futures to be complete, and returns the list of
    completed values.

    Arguments:
        futures (list): a list of :class:`~torch.futures.Future` objects.

    Returns:
        A list of the completed :class:`~torch.futures.Future` results. This
        method will throw an error if ``wait`` on any
        :class:`~torch.futures.Future` throws.
    """
    # Waiting on the combined future yields the list of sub-futures; each one
    # is then unwrapped to its value with its own ``wait`` call.
    return [fut.wait() for fut in torch._C._collect_all(futures).wait()]
# Type parameter for the value held by a ``Future``.
T = TypeVar("T")
GenericWithOneTypeVar = Generic[T]


try:
    # Combine the implementation class and the type class.
    class Future(_PyFuture, GenericWithOneTypeVar):
        r"""
        Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
        execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`.
        It also exposes a set of APIs to add callback functions and set results.
        """
except TypeError:
    # TypeError: metaclass conflict: the metaclass of a derived class
    # must be a (non-strict) subclass of the metaclasses of all its bases
    #
    # Resolve the conflict by deriving one metaclass from the metaclasses of
    # both bases and using it explicitly.
    class FutureMeta(_PyFuture.__class__, GenericWithOneTypeVar.__class__):
        pass

    # Combine the implementation class and the type class.
    class Future(_PyFuture, GenericWithOneTypeVar, metaclass=FutureMeta):
        r"""
        Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
        execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`.
        It also exposes a set of APIs to add callback functions and set results.
        """

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources