Matrix Ordering
Argmax
#include <raft/matrix/argmax.cuh>
namespace raft::matrix
-
template<typename math_t, typename idx_t, typename matrix_idx_t>
void argmax(raft::resources const &handle, raft::device_matrix_view<const math_t, matrix_idx_t, row_major> in, raft::device_vector_view<idx_t, matrix_idx_t> out)
Argmax: for each row, find the column index of the maximum value.
- Parameters:
handle – [in] raft handle
in – [in] input matrix of size (n_rows, n_cols)
out – [out] output vector of size n_rows
Argmin
#include <raft/matrix/argmin.cuh>
namespace raft::matrix
-
template<typename math_t, typename idx_t, typename matrix_idx_t>
void argmin(raft::resources const &handle, raft::device_matrix_view<const math_t, matrix_idx_t, row_major> in, raft::device_vector_view<idx_t, matrix_idx_t> out)
Argmin: for each row, find the column index of the minimum value.
- Parameters:
handle – [in] raft handle
in – [in] input matrix of size (n_rows, n_cols)
out – [out] output vector of size n_rows
Select-K
#include <raft/matrix/select_k.cuh>
namespace raft::matrix
-
template<typename T, typename IdxT>
void select_k(raft::resources const &handle, raft::device_matrix_view<const T, int64_t, row_major> in_val, std::optional<raft::device_matrix_view<const IdxT, int64_t, row_major>> in_idx, raft::device_matrix_view<T, int64_t, row_major> out_val, raft::device_matrix_view<IdxT, int64_t, row_major> out_idx, bool select_min, bool sorted = false)
Select the k smallest or largest key/value pairs from each row of the input data.
If you think of the input data `in_val` as a row-major matrix with `len` columns and `batch_size` rows, then this function selects the `k` smallest/largest values in each row. It writes them into the row-major matrix `out_val` of size (batch_size, k).
Example usage:
using namespace raft;
// get a 2D row-major array of values to search through
auto in_values = {... input device_matrix_view<const float, int64_t, row_major> ...}
// prepare output arrays
auto out_extents = make_extents<int64_t>(in_values.extent(0), k);
auto out_values = make_device_mdarray<float>(handle, out_extents);
auto out_indices = make_device_mdarray<int64_t>(handle, out_extents);
// search `k` smallest values in each row
matrix::select_k<float, int64_t>(
  handle, in_values, std::nullopt, out_values.view(), out_indices.view(), true);
- Template Parameters:
T – the type of the keys (what is being compared).
IdxT – the index type (what is being selected together with the keys).
- Parameters:
handle – [in] container of reusable resources
in_val – [in] input values [batch_size, len]; these are compared and selected.
in_idx – [in] optional input payload [batch_size, len]; typically, these are the indices of the corresponding `in_val`. If `in_idx` is `std::nullopt`, a contiguous array `0...len-1` is implied.
out_val – [out] output values [batch_size, k]; the k smallest/largest values from each row of `in_val`.
out_idx – [out] output payload (e.g. indices) [batch_size, k]; the payload selected together with `out_val`.
select_min – [in] whether to select the k smallest (true) or largest (false) keys.
sorted – [in] whether to ensure the selected pairs are sorted by value.
Column-wise Sort
#include <raft/matrix/col_wise_sort.cuh>
namespace raft::matrix
-
template<typename in_t, typename out_t, typename matrix_idx_t, typename sorted_keys_t>
void sort_cols_per_row(raft::resources const &handle, raft::device_matrix_view<const in_t, matrix_idx_t, raft::row_major> in, raft::device_matrix_view<out_t, matrix_idx_t, raft::row_major> out, sorted_keys_t &&sorted_keys_opt)
Sort the columns within each row of a row-major input matrix and return the sorted indices. This is modelled as a key-value sort: the keys are the input matrix elements and the values are the indices of those elements.
- Template Parameters:
in_t – element type of input matrix
out_t – element type of output matrix
matrix_idx_t – integer type for matrix indexing
sorted_keys_t – type of `sorted_keys_opt`: std::optional<raft::device_matrix_view<in_t, matrix_idx_t, raft::row_major>>
- Parameters:
handle – [in] raft handle
in – [in] input matrix
out – [out] output value(index) matrix
sorted_keys_opt – [out] optional output matrix that receives the sorted keys (i.e. the input values in sorted order)
-
template<typename ...Args, typename = std::enable_if_t<sizeof...(Args) == 3>>
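void sort_cols_per_row(Args... args)
Overload of `sort_cols_per_row` that accepts exactly three arguments (handle, in, out), i.e. calls that omit the optional `sorted_keys_opt` argument. It helps the compiler find the above overload when the user does not supply `sorted_keys_opt` (equivalent to passing `std::nullopt`). Please see above for the documentation of `sort_cols_per_row`.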