select_top_k_indices_sorted
- bamengine.utils.select_top_k_indices_sorted(values, k, descending=True)
Returns the indices of the k largest or smallest elements along the last axis, ordered by value.
Identifies the k elements (smallest or largest based on the descending flag) along the last axis of the input N-dimensional array values. It then returns the original indices of these k elements, sorted such that np.take_along_axis(values, returned_indices, axis=-1) yields values in the specified order (ascending or descending).
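The contract described above can be illustrated with plain NumPy. The sketch below does not call bamengine: it uses a full np.argsort as a stand-in (an assumption made for illustration, not the library's implementation) to build indices with the same property for descending=True, then gathers the values with np.take_along_axis.

```python
import numpy as np

values = np.array([[3.0, 1.0, 4.0, 1.5],
                   [9.0, 2.0, 6.0, 5.0]])
k = 2

# Stand-in for the documented behaviour with descending=True:
# sort the whole last axis ascending, reverse it, keep the first k columns.
order = np.argsort(values, axis=-1)[..., ::-1]
top_idx = order[..., :k]

# Gathering with the indices yields the values in the requested order.
top_vals = np.take_along_axis(values, top_idx, axis=-1)

print(top_idx)   # [[2 0], [0 2]]
print(top_vals)  # [[4. 3.], [9. 6.]]
```

The result has shape values.shape[:-1] + (k,), here (2, 2), and each row of top_vals runs from largest to smallest.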
- Parameters:
values (numpy.ndarray) – N-dimensional array of numerical values. Selection occurs along the last axis.
k (int) – The number of indices to select.
descending (bool, optional) – If True (default), selects the k largest elements (sorted largest to smallest). If False, selects the k smallest elements (sorted smallest to largest).
- Returns:
N-dimensional array of integer indices, shaped values.shape[:-1] + (selected_k,). selected_k is k (or 0 if k<=0, or values.shape[-1] if k exceeds the last axis dimension). Indices refer to positions along the last axis of values.
- Return type:
numpy.ndarray
Notes
Operates on the last axis of N-dimensional arrays.
Uses np.argpartition for efficient selection when k is less than the size of the last dimension.
Scalar inputs are treated as 1-element arrays.
Handles empty input arrays and k <= 0 by returning appropriately shaped empty index arrays.
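The notes above can be sketched roughly as follows. This is a hypothetical re-implementation written only from the documented behaviour, not bamengine's actual code; the name top_k_sketch is invented here, and the sketch assumes float or signed-integer input because it negates the values to select the largest elements.

```python
import numpy as np

def top_k_sketch(values, k, descending=True):
    """Hypothetical sketch of the documented behaviour (not bamengine's code)."""
    # Scalars are promoted to 1-element arrays.
    values = np.atleast_1d(np.asarray(values))
    n = values.shape[-1]
    # Clamp k to the valid range [0, n].
    k = max(0, min(int(k), n))
    if k == 0:
        # Empty last axis or k <= 0: return an appropriately shaped empty index array.
        return np.empty(values.shape[:-1] + (0,), dtype=np.intp)

    # Negating turns "k largest" into "k smallest" (assumes signed or float dtype).
    keyed = -values if descending else values
    if k == n:
        # Nothing to discard, so a full sort is all that is needed.
        return np.argsort(keyed, axis=-1)

    # Partial selection: the k smallest keys land in the first k slots, unordered.
    part = np.argpartition(keyed, k - 1, axis=-1)[..., :k]
    # Sort only those k candidates, then map back to original positions.
    sub = np.take_along_axis(keyed, part, axis=-1)
    order = np.argsort(sub, axis=-1)
    return np.take_along_axis(part, order, axis=-1)

a = np.array([[3.0, 1.0, 4.0, 1.5],
              [9.0, 2.0, 6.0, 5.0]])
print(top_k_sketch(a, 2))                    # [[2 0], [0 2]]
print(top_k_sketch(a, 2, descending=False))  # [[1 3], [1 3]]
print(top_k_sketch(a, 0).shape)              # (2, 0)
```

Partitioning first and sorting only the k selected candidates avoids a full sort of the last axis when k is small relative to its length.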