torch.futures
Warning
The torch.futures package is experimental and subject to change.
This package provides a Future type that encapsulates an asynchronous execution, and a set of utility functions to simplify operations on Future objects. Currently, the Future type is primarily used by the Distributed RPC Framework.
class torch.futures.Future

Wrapper around a torch._C.Future which encapsulates an asynchronous execution of a callable, e.g. rpc_async(). It also exposes a set of APIs to add callback functions and set results.
set_result(result)

Set the result for this Future, which will mark this Future as completed and trigger all attached callbacks. Note that a Future cannot be marked completed twice.

Parameters
    result (object) – the result object of this Future.
Example::

    >>> import threading
    >>> import time
    >>> import torch
    >>>
    >>> def slow_set_future(fut, value):
    ...     time.sleep(0.5)
    ...     fut.set_result(value)
    ...
    >>> fut = torch.futures.Future()
    >>> t = threading.Thread(
    ...     target=slow_set_future,
    ...     args=(fut, torch.ones(2) * 3)
    ... )
    >>> t.start()
    >>>
    >>> print(fut.wait())
    tensor([3., 3.])
    >>> t.join()
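The rule that a Future cannot be marked completed twice can be checked directly. The sketch below is illustrative only: the helper name `set_twice` is hypothetical, and it assumes a recent PyTorch release, where the second `set_result()` call is rejected with a `RuntimeError` and the first result is kept.

```python
import torch


def set_twice(first, second):
    """Hypothetical helper: complete a Future, then try to complete it again.

    Returns the value held by the Future and whether the second
    set_result() call was rejected.
    """
    fut = torch.futures.Future()
    fut.set_result(first)  # the first completion succeeds
    rejected = False
    try:
        fut.set_result(second)  # assumption: recent releases raise RuntimeError here
    except RuntimeError:
        rejected = True
    return fut.wait(), rejected


value, rejected = set_twice(1, 2)
print(value, rejected)  # the first result is kept: 1 True
```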
then(callback)

Append the given callback function to this Future, which will be run when the Future is completed. Multiple callbacks can be added to the same Future, and will be invoked in the same order as they were added. The callback must take one argument, which is the reference to this Future. The callback function can use the Future.wait() API to get the value.

Parameters
    callback (Callable) – a Callable that takes this Future as the only argument.

Returns
    A new Future object that holds the return value of the callback and will be marked as completed when the given callback finishes.
Example::

    >>> import torch
    >>>
    >>> def callback(fut):
    ...     print(f"RPC return value is {fut.wait()}.")
    ...
    >>> fut = torch.futures.Future()
    >>> # The inserted callback will print the return value once fut is
    >>> # completed (for an RPC Future, when the response from "worker1" arrives)
    >>> cb_fut = fut.then(callback)
    >>> chain_cb_fut = cb_fut.then(
    ...     lambda x: print(f"Chained cb done. {x.wait()}")
    ... )
    >>> fut.set_result(5)
    RPC return value is 5.
    Chained cb done. None
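Because then() returns a new Future holding the callback's return value, a callback can transform a result instead of only printing it, and the returned Futures can be chained. A minimal sketch, assuming a recent PyTorch release; the callback names `double` and `add_one` are hypothetical. It also attaches a callback to a Future that has already completed.

```python
import torch


def double(fut):
    # Hypothetical callback: read the parent's value and return a new one.
    return fut.wait() * 2


def add_one(fut):
    # Hypothetical callback used to show chaining.
    return fut.wait() + 1


fut = torch.futures.Future()
doubled = fut.then(double)       # holds double(fut) once fut completes
chained = doubled.then(add_one)  # holds add_one(doubled) once doubled completes
fut.set_result(5)

print(doubled.wait())  # 10
print(chained.wait())  # 11

# A callback can also be attached after the Future has already completed.
late = fut.then(add_one)
print(late.wait())  # 6
```

Returning values from callbacks, rather than printing inside them as in the example above, is what makes the chained Futures useful to downstream code.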
wait()

Block until the value of this Future is ready.

Returns
    The value held by this Future. If the function (callback or RPC) creating the value has thrown an error, this wait method will also throw an error.
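The error behavior of wait() can be seen with a callback that raises. In this sketch the callback name `failing_callback` is hypothetical. Recent PyTorch releases re-raise the failure from wait() on the Future returned by then(), typically wrapped as a `RuntimeError` whose message includes the original error, so the sketch catches a generic `Exception` rather than relying on the exact type.

```python
import torch


def failing_callback(fut):
    # Hypothetical callback that fails while producing its value.
    raise ValueError("boom")


fut = torch.futures.Future()
cb_fut = fut.then(failing_callback)
fut.set_result(1)  # completes fut; the callback then raises

print(fut.wait())  # the parent Future is unaffected: 1

raised = False
try:
    cb_fut.wait()  # the callback's error surfaces here
except Exception as err:  # exact exception type may vary by release
    raised = True
    print(f"wait() raised: {err}")
```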
torch.futures.collect_all(futures)

Collects the provided Future objects into a single combined Future that is completed when all of the sub-futures are completed.

Parameters
    futures (list) – a list of Future objects.

Returns
    Returns a Future object whose value is a list of the passed-in Futures.
Example::

    >>> import torch
    >>>
    >>> fut0 = torch.futures.Future()
    >>> fut1 = torch.futures.Future()
    >>>
    >>> fut = torch.futures.collect_all([fut0, fut1])
    >>>
    >>> fut0.set_result(0)
    >>> fut1.set_result(1)
    >>>
    >>> fut_list = fut.wait()
    >>> print(f"fut0 result = {fut_list[0].wait()}")
    fut0 result = 0
    >>> print(f"fut1 result = {fut_list[1].wait()}")
    fut1 result = 1
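collect_all() composes with then(): a callback on the combined Future can reduce the sub-futures' values into a single result. A minimal sketch, assuming a recent PyTorch release; `sum_all` is a hypothetical callback name. The sub-futures are completed in reverse order to show that the combined Future waits for all of them while its list still follows the order passed in.

```python
import torch


def sum_all(combined):
    # combined.wait() is the list of sub-futures, in the order passed to collect_all
    return sum(f.wait() for f in combined.wait())


futs = [torch.futures.Future() for _ in range(3)]
combined = torch.futures.collect_all(futs)
total = combined.then(sum_all)

# Complete the sub-futures in reverse order; total only completes after the last one.
for i, f in reversed(list(enumerate(futs))):
    f.set_result(i * 10)

print([f.wait() for f in combined.wait()])  # [0, 10, 20]
print(total.wait())  # 30
```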