cuML on GPU and CPU#
cuML is a suite of fast, GPU-accelerated machine learning algorithms with a scikit-learn-compatible API, designed for data science and analytical tasks. Starting with version 23.10, cuML can also run on CPU systems, making it easier to use (with no code changes required) in the following ways:
- Allows users to prototype on systems without GPUs.
- Allows library integrations without the need for dispatching and boilerplate code.
- Allows users to train on one type of system and infer on the other for a subset of estimators (which will grow with each version).
- Provides compatibility with the GPU/CPU open source PyData ecosystem.
The majority of cuML's estimators can run on both CPU and GPU systems, and a subset of them support exporting trained models between GPU and CPU systems. The following table shows support for the most common estimators:
| Category | Algorithm | Supports Execution on CPU | Supports Exporting between CPU and GPU |
|---|---|---|---|
| Clustering | Density-Based Spatial Clustering of Applications with Noise (DBSCAN) | Yes | No |
| | Hierarchical Density-Based Spatial Clustering of Applications with Noise (HDBSCAN) | Yes | Partial |
| | K-Means | Yes | No |
| | Single-Linkage Agglomerative Clustering | No | No |
| Dimensionality Reduction | Principal Components Analysis (PCA) | Yes | Yes |
| | Incremental PCA | No | No |
| | Truncated Singular Value Decomposition (tSVD) | Yes | Yes |
| | Uniform Manifold Approximation and Projection (UMAP) | Yes | Partial |
| | Random Projection | No | No |
| | t-Distributed Stochastic Neighbor Embedding (TSNE) | No | No |
| Linear Models for Regression or Classification | Linear Regression (OLS) | Yes | Yes |
| | Linear Regression with Lasso or Ridge Regularization | Yes | Yes |
| | ElasticNet Regression | Yes | Yes |
| | LARS Regression | No | No |
| | Logistic Regression | Yes | Yes |
| | Naive Bayes | No | No |
| | Solvers | Yes | |
| Nonlinear Models for Regression or Classification | Random Forest (RF) Classification | No | Partial |
| | Random Forest (RF) Regression | No | Partial |
| | Inference for decision tree-based models | No | No |
| | Nearest Neighbors (NN) | Yes | Yes |
| | K-Nearest Neighbors (KNN) Classification | Yes | Yes |
| | K-Nearest Neighbors (KNN) Regression | Yes | Yes |
| | Support Vector Machine Classifier (SVC) | No | No |
| | Epsilon-Support Vector Regression (SVR) | No | No |
| Time Series | Holt-Winters Exponential Smoothing | No | No |
| | Auto-regressive Integrated Moving Average (ARIMA) | No | No |
This guarantees that the same code runs on both GPU and CPU systems. Version 23.12 is scheduled to add CPU support for the Random Forest and Support Vector Machine estimators.
Installation#
For GPU systems, cuML still follows the RAPIDS requirements and nothing has changed about installing it. The cuml package and wheels are universal and can run in both GPU and CPU modes. On systems without a GPU, the cuml-cpu package can, like other RAPIDS packages, be installed with conda/mamba:
mamba install -c rapidsai -c nvidia -c conda-forge cuml-cpu=23.10
# mamba install -c rapidsai-nightly -c nvidia -c conda-forge cuml-cpu=23.12 # for nightly builds
cuML 23.10 supports Linux and WSL2, on both GPU and CPU systems, using conda. cuML 23.12 will add support for pip wheels and for CPU execution on macOS.
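After installing, a quick sanity check confirms that the package imports correctly on a CPU-only node. This is a minimal sketch; it assumes only that the cuml module exposes a __version__ attribute, which holds for both packages:

import cuml  # the same import works with either cuml or cuml-cpu installed
print(cuml.__version__)  # e.g. 23.10.00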
How to Use#
There are two main ways to use the CPU capabilities of cuML:
1. Using the CPU Package directly#
The CPU package, cuml-cpu, is a subset of the cuml package, so besides the difference in installation, no code changes are needed to run supported estimators. For example, the following script can run both on a system with a GPU and cuml, and on a system without a GPU and cuml-cpu:
[1]:
import cuml  # no change is needed, even for the import!
import pandas as pd
from cuml.manifold.umap import UMAP
from sklearn import datasets
from sklearn.manifold import trustworthiness

# load the iris dataset from sklearn and extract the required information
iris = datasets.load_iris()
iris_df = pd.DataFrame(iris.data, columns=iris.feature_names)

# define the cuML UMAP model and use fit_transform to obtain the
# low-dimensional embedding of the input dataset
embedding = UMAP(
    n_neighbors=10, min_dist=0.01, init="random"
).fit_transform(iris_df)

# calculate the trustworthiness of the results obtained from the cuML UMAP
trust = trustworthiness(iris_df, embedding)
print(trust)
0.9775774647887324
This makes it possible to prototype on CPU systems and then run code on GPU servers, or the other way around. Some estimators support training on one type of system and then exporting the model to the other, as shown in the corresponding section below.
2. Managing Execution Platform with the GPU package#
In addition to allowing the same code to run on CPU systems, the full cuML package lets users control which device executes parts of the code. So beyond the first example, which can run on a CPU-only system with cuml-cpu, a system with the full cuml package can also execute in CPU mode.
For example, using the following data:
[2]:
import cuml
from cuml.neighbors import NearestNeighbors
from cuml.datasets import make_regression, make_blobs
from cuml.model_selection import train_test_split

X_blobs, y_blobs = make_blobs(n_samples=2000, n_features=20)
X_train_blobs, X_test_blobs, y_train_blobs, y_test_blobs = train_test_split(
    X_blobs, y_blobs, test_size=0.2, shuffle=True
)

X_reg, y_reg = make_regression(n_samples=2000, n_features=20)
X_train_reg, X_test_reg, y_train_reg, y_test_reg = train_test_split(
    X_reg, y_reg, test_size=0.2, shuffle=True
)
There are two ways to control the execution of the code:
a) using_device_type context manager#
[3]:
from cuml.neighbors import NearestNeighbors
from cuml.common.device_selection import using_device_type

nn = NearestNeighbors()
with using_device_type('cpu'):
    nn.fit(X_train_blobs)
    nearest_neighbors = nn.kneighbors(X_test_blobs)
This is useful for prototyping, but also for running different estimators on different devices, for example when a dataset is small enough that moving it to the GPU would cost more than the acceleration gained, as in the sketch below.
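As a minimal sketch of this idea (K-Means supports CPU execution per the table above; the subset size and cluster count here are arbitrary choices for illustration):

from cuml.cluster import KMeans
from cuml.common.device_selection import using_device_type

# small subset: run on CPU, where transfer overhead could dominate
with using_device_type('cpu'):
    kmeans_small = KMeans(n_clusters=3).fit(X_train_blobs[:100])

# full dataset: run on GPU to benefit from acceleration
with using_device_type('gpu'):
    kmeans_full = KMeans(n_clusters=3).fit(X_train_blobs)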
Additionally, it allows running estimators with parameters that the GPU implementation does not support:
from cuml.manifold import UMAP

umap_model = UMAP(angular_rp_forest=True)  # `angular_rp_forest` hyperparameter only available in the UMAP library
with using_device_type('cpu'):
    umap_model.fit(X_train_blobs)  # will run the UMAP library with the hyperparameter

with using_device_type('gpu'):
    transformed = umap_model.transform(X_test_blobs)  # will run the cuML implementation of UMAP, ignoring the unsupported parameter
An upcoming feature will allow this dispatch to happen automatically. This can be very useful for library integrators: if a user passes parameters not supported on GPU, the code will automatically dispatch to a CPU implementation.
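Until then, an integrator could emulate that dispatch manually. The following is a hypothetical sketch, not part of cuML's API: the GPU_UNSUPPORTED set and the fit_umap helper are illustrative names introduced here.

from cuml.manifold import UMAP
from cuml.common.device_selection import using_device_type

# hypothetical set of parameters assumed unsupported on GPU
GPU_UNSUPPORTED = {"angular_rp_forest"}

def fit_umap(X, **params):
    # dispatch to CPU when any GPU-unsupported parameter is requested
    device = 'cpu' if GPU_UNSUPPORTED & set(params) else 'gpu'
    model = UMAP(**params)
    with using_device_type(device):
        model.fit(X)
    return model

umap_cpu = fit_umap(X_train_blobs, angular_rp_forest=True)  # runs on CPU
umap_gpu = fit_umap(X_train_blobs, n_neighbors=15)          # runs on GPU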
b) Global configuration#
By default, cuml executes estimators on the GPU/device. But it also provides a global configuration option to change the default device, which can be useful on shared systems where cuML runs alongside deep learning frameworks that occupy most of a GPU. This is done with the set_global_device_type function:
[4]:
from cuml.common.device_selection import set_global_device_type, get_global_device_type
initial_device_type = get_global_device_type()
print('default execution device:', initial_device_type)
default execution device: DeviceType.device
[5]:
set_global_device_type('cpu')
print('new device type:', get_global_device_type())
new device type: DeviceType.host
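To switch the default back to GPU execution afterwards, the same function can be called again with 'gpu':

set_global_device_type('gpu')
print('new device type:', get_global_device_type())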
Cross Device Training and Inference Serialization#
As stated before, a subset of the estimators that can be executed on the CPU also allow serializing a model trained on one type of device (CPU or GPU) and deserializing it on the other. A simple API is provided for this. For example, to train a model on GPU but deploy it on CPU, first train the estimator on device and save it to disk:
[6]:
import pickle

from cuml.linear_model import LinearRegression

lin_reg = LinearRegression()
lin_reg.fit(X_train_reg, y_train_reg)

with open("lin_reg.pkl", "wb") as f:
    pickle.dump(lin_reg, f)
del lin_reg
/opt/conda/envs/docs/lib/python3.10/site-packages/cuml/internals/api_decorators.py:382: UserWarning: Starting from version 23.08, the new 'copy_X' parameter defaults to 'True', ensuring a copy of X is created after passing it to fit(), preventing any changes to the input, but with increased memory usage. This represents a change in behavior from previous versions. With `copy_X=False` a copy might still be created if necessary. Explicitly set 'copy_X' to either True or False to suppress this warning.
return init_func(self, *args, **filtered_kwargs)
Then, on the other device/server, recover the estimator on a node with cuml-cpu installed:
[7]:
with open("lin_reg.pkl", "rb") as f:
    recovered_lin_reg = pickle.load(f)

predictions = recovered_lin_reg.predict(X_test_reg)
print(predictions[0:10])
[[ -4.869449]
[ 19.966135]
[-33.391266]
[ 25.461655]
[ 72.99573 ]
[180.68678 ]
[-18.156178]
[ 52.316895]
[128.38304 ]
[-34.872208]]
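The reverse direction follows the same pattern. As a minimal sketch, assuming the same training data is available on the CPU-only node, a model trained with cuml-cpu can be pickled and later loaded on a GPU system:

import pickle

from cuml.linear_model import LinearRegression

# on the CPU-only node (cuml-cpu): train and serialize the estimator
cpu_model = LinearRegression()
cpu_model.fit(X_train_reg, y_train_reg)
with open("cpu_lin_reg.pkl", "wb") as f:
    pickle.dump(cpu_model, f)

# on the GPU node (full cuml): deserialize and run inference on device
with open("cpu_lin_reg.pkl", "rb") as f:
    gpu_model = pickle.load(f)
gpu_predictions = gpu_model.predict(X_test_reg)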
Conclusions#
cuML's CPU capabilities are designed to facilitate a variety of use cases, lower the barriers to using cuML, and increase the flexibility of integrating and deploying the library.
Upcoming versions of cuML will expand the set of supported estimators, both for CPU execution and for serializing/exporting models between systems with and without GPUs.