select_top_k_indices_sorted

bamengine.utils.select_top_k_indices_sorted(values, k, descending=True)

Returns the indices of the k largest or k smallest elements along the last axis, ordered by value.

Identifies the k elements (smallest or largest based on the descending flag) along the last axis of the input N-dimensional array values. It then returns the original indices of these k elements, sorted such that np.take_along_axis(values, returned_indices, axis=-1) yields values in the specified order (ascending or descending).
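The relationship with np.take_along_axis described above can be illustrated with a small stand-in. The function top_k_sorted_reference below is hypothetical: it reproduces the documented contract with a full np.argsort and is not the bamengine implementation; the sample array is made up for illustration.

```python
import numpy as np


def top_k_sorted_reference(values, k, descending=True):
    """Hypothetical stand-in for the documented contract (full sort, not bamengine's code)."""
    values = np.asarray(values)
    order = np.argsort(values, axis=-1, kind="stable")  # indices in ascending value order
    if descending:
        order = order[..., ::-1]  # flip so the largest value comes first
    return order[..., : max(k, 0)]  # keep the first k positions along the last axis


values = np.array([[3.0, 1.0, 4.0, 1.5],
                   [2.0, 7.0, 0.5, 6.0]])

idx = top_k_sorted_reference(values, k=2, descending=True)
top = np.take_along_axis(values, idx, axis=-1)

print(idx)  # [[2 0]
            #  [1 3]]
print(top)  # [[4. 3.]
            #  [7. 6.]]
```

Each row of idx holds positions along the last axis of values, so gathering with np.take_along_axis yields each row's two largest entries, largest first.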

Parameters:
  • values (numpy.ndarray) – N-dimensional array of numerical values. Selection occurs along the last axis.

  • k (int) – The number of indices to select.

  • descending (bool, optional) – If True (default), selects k largest elements (sorted largest to smallest). If False, selects k smallest elements (sorted smallest to largest).

Returns:

N-dimensional array of integer indices with shape values.shape[:-1] + (selected_k,), where selected_k is k clamped to the range [0, values.shape[-1]]: it is 0 if k <= 0 and values.shape[-1] if k exceeds the size of the last axis. Indices refer to positions along the last axis of values.

Return type:

numpy.ndarray
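As a quick check of the shape rule stated above, the clamping of k can be written out directly. The helper expected_index_shape is a hypothetical illustration and not part of bamengine; it only computes the shape that the documentation says the result should have.

```python
def expected_index_shape(values_shape, k):
    """Shape of the returned index array per the documented rule (hypothetical helper)."""
    last = values_shape[-1]
    selected_k = min(max(k, 0), last)  # clamp k to the range [0, last]
    return tuple(values_shape[:-1]) + (selected_k,)


print(expected_index_shape((4, 10), 3))    # (4, 3)
print(expected_index_shape((4, 10), 0))    # (4, 0)
print(expected_index_shape((4, 10), 25))   # (4, 10)
print(expected_index_shape((2, 5, 0), 3))  # (2, 5, 0)
```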

Notes

  • Operates on the last axis of N-dimensional arrays.

  • Uses np.argpartition for efficient selection when k is less than the size of the last dimension.

  • Scalar inputs are treated as 1-element arrays.

  • Handles empty input arrays and k <= 0 by returning appropriately shaped empty index arrays.
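The notes above suggest a partition-then-sort approach. The following is a minimal sketch of how such a function could be written, assuming signed or floating-point input so that negation is safe; top_k_indices_sketch is a hypothetical name, and the actual bamengine source may differ in its details.

```python
import numpy as np


def top_k_indices_sketch(values, k, descending=True):
    """Sketch of partition-then-sort top-k selection (an assumption, not bamengine's source)."""
    values = np.atleast_1d(np.asarray(values))  # a scalar becomes a 1-element array
    n = values.shape[-1]
    k = min(max(int(k), 0), n)  # clamp k to the range [0, n]
    if k == 0 or values.size == 0:
        # Empty input or nothing requested: return an empty, correctly shaped index array.
        return np.empty(values.shape[:-1] + (k,), dtype=np.intp)
    # Negate so that "k smallest keys" means "k largest values" when descending.
    keys = -values if descending else values
    if k == n:
        # Every element is selected, so a plain sort is all that is needed.
        return np.argsort(keys, axis=-1, kind="stable")
    # argpartition places the k smallest keys (in arbitrary order) in the first k slots.
    part = np.argpartition(keys, k - 1, axis=-1)[..., :k]
    # Sort only those k candidates, then map back to positions in the original array.
    part_keys = np.take_along_axis(keys, part, axis=-1)
    order = np.argsort(part_keys, axis=-1, kind="stable")
    return np.take_along_axis(part, order, axis=-1)


sample = np.array([5.0, 1.0, 9.0, 3.0, 7.0])
print(top_k_indices_sketch(sample, 3))                    # [2 4 0]
print(top_k_indices_sketch(sample, 3, descending=False))  # [1 3 0]
```

Partitioning first means only the k selected candidates are fully sorted, which is where the efficiency gain over sorting the whole axis comes from when k is small relative to the axis length.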