RAFT Dask API

Dask-based Multi-Node Multi-GPU Communicator

class raft_dask.common.Comms(comms_p2p=False, client=None, verbose=False, streams_per_handle=0, nccl_root_location='scheduler')

Initializes and manages the underlying NCCL and UCX comms handles across the workers of a Dask cluster. init() is expected to be called explicitly before the comms are used, and destroy() should be called once they are no longer needed so the underlying resources can be cleaned up. This class is not thread-safe.

Examples

# The following code block assumes we have wrapped a C++
# function in a Python function called `run_algorithm`,
# which takes a `raft::handle_t` as its single argument.
# Once the `Comms` instance is successfully initialized,
# the underlying `raft::handle_t` will contain an instance
# of `raft::comms::comms_t`.

from dask_cuda import LocalCUDACluster
from dask.distributed import Client, wait

from raft_dask.common import Comms, local_handle

cluster = LocalCUDACluster()
client = Client(cluster)

def _use_comms(sessionId):
    # Runs on a worker: look up that worker's handle for this session
    return run_algorithm(local_handle(sessionId))

comms = Comms(client=client)
comms.init()

# Submit one task to each worker taking part in the comms session
futures = [client.submit(_use_comms,
                         comms.sessionId,
                         workers=[w],
                         pure=False)  # Don't memoize
           for w in comms.worker_addresses]
wait(futures, timeout=5)

comms.destroy()
client.close()
cluster.close()

Methods

destroy()

Shuts down initialized comms and cleans up resources.

init([workers])

Initializes the underlying comms.

worker_info(workers)

Builds a dictionary of { (worker_address, worker_port) : (worker_rank, worker_port) }.

create_nccl_uniqueid

destroy()

Shuts down initialized comms and cleans up resources. This is called automatically by the Comms destructor, but it can be called earlier to release resources sooner.
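Because destroy() is what releases the NCCL and UCX resources, one way to make sure it runs even when the work in between raises is to wrap the init()/destroy() pair in a context manager. The sketch below relies only on the init(workers=None) and destroy() methods documented on this page; `managed_comms` is a hypothetical helper, not part of raft_dask.

```python
from contextlib import contextmanager


@contextmanager
def managed_comms(comms, workers=None):
    """Hypothetical helper (not part of raft_dask): initialize the comms on
    entry and, once init() has succeeded, always destroy them on exit."""
    comms.init(workers=workers)
    try:
        yield comms
    finally:
        # Runs on normal exit and when the body raises, so NCCL/UCX
        # resources are released without waiting for the destructor.
        comms.destroy()
```

With a running cluster, `with managed_comms(Comms(client=client)) as comms:` could take the place of the explicit init() and destroy() calls in the example. If init() itself raises, destroy() is not called by this sketch.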

init(workers=None)

Initializes the underlying comms. NCCL is always initialized; UCX is initialized only if comms_p2p=True.

Parameters:
workers : Sequence, optional

Unique collection of workers for initializing comms.
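When the workers sequence is assembled by hand, duplicates should be removed before it is passed in. A minimal sketch, assuming worker addresses are plain strings like the keys of `client.scheduler_info()["workers"]`; `unique_workers` is a hypothetical helper and the addresses are invented for illustration.

```python
from typing import Hashable, Iterable, List


def unique_workers(workers: Iterable[Hashable]) -> List[Hashable]:
    """Hypothetical helper (not part of raft_dask): drop duplicate worker
    addresses while keeping the order in which each first appears."""
    # dicts preserve insertion order (Python 3.7+), so fromkeys() keeps the
    # first occurrence of each address and skips later repeats.
    return list(dict.fromkeys(workers))


# Invented addresses; the first one is listed twice.
addresses = [
    "tcp://10.0.0.1:40000",
    "tcp://10.0.0.2:40000",
    "tcp://10.0.0.1:40000",
]
print(unique_workers(addresses))
# ['tcp://10.0.0.1:40000', 'tcp://10.0.0.2:40000']
```

The result can then be passed as `comms.init(workers=unique_workers(addresses))`. Keeping first-seen order means the same input list always yields the same worker ordering.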

worker_info(workers)

Builds a dictionary of { (worker_address, worker_port) : (worker_rank, worker_port) }.