From b0f8bdfb1668a0741716c1f1971c568480f7e661 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 28 Sep 2025 14:09:34 +0200 Subject: [PATCH 01/22] implem done; clean-up & comments todo --- sklearn/utils/stats.py | 97 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 96 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 35cadf0ca7372..777b68ff8b98e 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -7,7 +7,7 @@ ) -def _weighted_percentile( +def _old_weighted_percentile( array, sample_weight, percentile_rank=50, average=False, xp=None ): """Compute the weighted percentile. @@ -188,3 +188,98 @@ def _weighted_percentile( result = array[percentile_in_sorted, col_indices] return result[0] if n_dim == 1 else result + + +def _weighted_percentile_inner(x, w, target_sums, out, average, xp): + if x.size == 1: + out[:] = x + return + i = x.size // 2 + partitioner = xp.argpartition(x, x.size // 2) + w_left = w[partitioner[:i]] + sum_left = xp.sum(w_left) + j = xp.searchsorted(target_sums, sum_left) + target_sums[j:] -= sum_left + if j > 0: + _weighted_percentile_inner( + x[partitioner[:i]], w_left, target_sums[:j], out[:j], average, xp + ) + if j < target_sums.size: + idx_0 = xp.searchsorted(target_sums[j:], 0, side="right") + if idx_0 > 0: + out[j : j + idx_0] = ( + (x[partitioner[:i]].max() + x[partitioner[i:]].min()) / 2 + if average + else x[partitioner[:i]].max() + ) + j += idx_0 + if j < target_sums.size: + _weighted_percentile_inner( + x[partitioner[i:]], + w[partitioner[i:]], + target_sums[j:], + out[j:], + average, + xp, + ) + + +def _weighted_percentile( + array, sample_weight, percentile_rank=50, average=False, xp=None +): + xp, _, device = get_namespace_and_device(array) + + # `sample_weight` should follow `array` for dtypes + # XXX: ^ why? Also: why floating dtype? (is this really floating, I don't think so) + floating_dtype = _find_matching_floating_dtype(array, xp=xp) + array = xp.asarray(array, dtype=floating_dtype, device=device) + sample_weight = xp.asarray(sample_weight, dtype=floating_dtype, device=device) + + n_dim = array.ndim + if n_dim == 0: + return array + if array.ndim == 1: + array = xp.reshape(array, (-1, 1)) + n_features = array.shape[1] + + n_dim_percentile = 0 + q = xp.asarray([percentile_rank / 100]) + sorter = xp.argsort(q) + result = xp.empty((n_features, q.size), dtype=floating_dtype) + sorted_q = q[sorter] + result_sorted = xp.empty(q.size, dtype=floating_dtype) + + for feature_idx in range(n_features): + x = xp.ascontiguousarray(array[..., feature_idx]) + mask_nnan = ~xp.isnan(x) + x = x[mask_nnan] + if x.size == 0: + result[feature_idx, ...] = xp.nan + continue + w = ( + sample_weight[mask_nnan, feature_idx] + if sample_weight.ndim == 2 + else sample_weight[mask_nnan] + ) + mask_nz = w != 0 + if not mask_nz.all(): + w = w[mask_nz] + x[mask_nz] + weights_sum = xp.sum(w) + if weights_sum == 0: + result[feature_idx, ...] 
= xp.max(x) + continue + _weighted_percentile_inner( + x, + w, + target_sums=weights_sum * sorted_q, + out=result_sorted, + average=average, + xp=xp, + ) + result[feature_idx, sorter] = result_sorted + + if n_dim_percentile == 0: + result = result[..., 0] + + return result[0] if n_dim == 1 else result From 2140c82632130eea982ab3de0a928483330d200f Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 28 Sep 2025 14:33:08 +0200 Subject: [PATCH 02/22] conform to array-API --- sklearn/utils/stats.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 777b68ff8b98e..cef3bf9c09141 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -244,13 +244,13 @@ def _weighted_percentile( n_dim_percentile = 0 q = xp.asarray([percentile_rank / 100]) - sorter = xp.argsort(q) + sorter = xp.argsort(q, stable=False) result = xp.empty((n_features, q.size), dtype=floating_dtype) sorted_q = q[sorter] result_sorted = xp.empty(q.size, dtype=floating_dtype) for feature_idx in range(n_features): - x = xp.ascontiguousarray(array[..., feature_idx]) + x = array[..., feature_idx] mask_nnan = ~xp.isnan(x) x = x[mask_nnan] if x.size == 0: From 84c0240999131efbd98f1e6bf5d69212b1fcd6f9 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 28 Sep 2025 14:39:22 +0200 Subject: [PATCH 03/22] cleanup --- sklearn/utils/stats.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index cef3bf9c09141..a4a07c2dc8d1b 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -200,27 +200,27 @@ def _weighted_percentile_inner(x, w, target_sums, out, average, xp): sum_left = xp.sum(w_left) j = xp.searchsorted(target_sums, sum_left) target_sums[j:] -= sum_left + x_left = None if j > 0: + x_left = x[partitioner[:i]] _weighted_percentile_inner( - x[partitioner[:i]], w_left, target_sums[:j], out[:j], average, xp + x_left, w_left, target_sums[:j], out[:j], average, xp ) if j < target_sums.size: idx_0 = xp.searchsorted(target_sums[j:], 0, side="right") if idx_0 > 0: + x_left = x[partitioner[:i]] if x_left is None else x_left out[j : j + idx_0] = ( - (x[partitioner[:i]].max() + x[partitioner[i:]].min()) / 2 + (x_left.max() + x[partitioner[i:]].min()) / 2 if average - else x[partitioner[:i]].max() + else x_left.max() ) j += idx_0 if j < target_sums.size: + x_right = x[partitioner[i:]] + w_right = w[partitioner[i:]] _weighted_percentile_inner( - x[partitioner[i:]], - w[partitioner[i:]], - target_sums[j:], - out[j:], - average, - xp, + x_right, w_right, target_sums[j:], out[j:], average, xp ) From 78226738c9b44efb20854d2cbbe3be913dce03ad Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 28 Sep 2025 14:52:15 +0200 Subject: [PATCH 04/22] comments; docstring; cleanups --- sklearn/utils/stats.py | 206 +++++++++++------------------------------ 1 file changed, 56 insertions(+), 150 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index a4a07c2dc8d1b..9165ad73e4a88 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -7,7 +7,43 @@ ) -def _old_weighted_percentile( +def _weighted_percentile_inner(x, w, target_sums, out, average, xp): + if x.size == 1: + out[:] = x + return + i = x.size // 2 + partitioner = xp.argpartition(x, i) + w_left = w[partitioner[:i]] + sum_left = xp.sum(w_left) + j = xp.searchsorted(target_sums, sum_left) + target_sums[j:] -= sum_left + x_left = None + if j > 0: + # some quantiles are to be found on the left side of the partition + 
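# At this point `target_sums` holds the sorted quantile targets expressed relative
# to the current slice of `x`, and `sum_left` is the total weight of the `i`
# smallest values. `j` counts the targets strictly below `sum_left`: those are
# resolved inside the left partition by the recursive call just below, while the
# remaining targets have already been re-expressed relative to the right partition
# by the subtraction above; targets that hit the boundary exactly become 0 and are
# handled by the `searchsorted(..., 0, side="right")` block further down.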
x_left = x[partitioner[:i]] + _weighted_percentile_inner( + x_left, w_left, target_sums[:j], out[:j], average, xp + ) + if j >= target_sums.size: + return + idx_0 = xp.searchsorted(target_sums[j:], 0, side="right") + if idx_0 > 0: + # some quantiles are precisely at the index of the partition + x_left = x[partitioner[:i]] if x_left is None else x_left + out[j : j + idx_0] = ( + (x_left.max() + x[partitioner[i:]].min()) / 2 if average else x_left.max() + ) + j += idx_0 + if j < target_sums.size: + # some quantiles are to be found on the right side of the partition + x_right = x[partitioner[i:]] + w_right = w[partitioner[i:]] + _weighted_percentile_inner( + x_right, w_right, target_sums[j:], out[j:], average, xp + ) + + +def _weighted_percentile( array, sample_weight, percentile_rank=50, average=False, xp=None ): """Compute the weighted percentile. @@ -83,157 +119,13 @@ def _old_weighted_percentile( Weighted percentile at the requested probability level. """ xp, _, device = get_namespace_and_device(array) - # `sample_weight` should follow `array` for dtypes - floating_dtype = _find_matching_floating_dtype(array, xp=xp) - array = xp.asarray(array, dtype=floating_dtype, device=device) - sample_weight = xp.asarray(sample_weight, dtype=floating_dtype, device=device) - - n_dim = array.ndim - if n_dim == 0: - return array - if array.ndim == 1: - array = xp.reshape(array, (-1, 1)) - # When sample_weight 1D, repeat for each array.shape[1] - if array.shape != sample_weight.shape and array.shape[0] == sample_weight.shape[0]: - sample_weight = xp.tile(sample_weight, (array.shape[1], 1)).T - # Sort `array` and `sample_weight` along axis=0: - sorted_idx = xp.argsort(array, axis=0) - sorted_weights = xp.take_along_axis(sample_weight, sorted_idx, axis=0) - - # Set NaN values in `sample_weight` to 0. Only perform this operation if NaN - # values present to avoid temporary allocations of size `(n_samples, n_features)`. - n_features = array.shape[1] - largest_value_per_column = array[ - sorted_idx[-1, ...], xp.arange(n_features, device=device) - ] - # NaN values get sorted to end (largest value) - if xp.any(xp.isnan(largest_value_per_column)): - sorted_nan_mask = xp.take_along_axis(xp.isnan(array), sorted_idx, axis=0) - sorted_weights[sorted_nan_mask] = 0 - - # Compute the weighted cumulative distribution function (CDF) based on - # `sample_weight` and scale `percentile_rank` along it. - # - # Note: we call `xp.cumulative_sum` on the transposed `sorted_weights` to - # ensure that the result is of shape `(n_features, n_samples)` so - # `xp.searchsorted` calls take contiguous inputs as a result (for - # performance reasons). - weight_cdf = xp.cumulative_sum(sorted_weights.T, axis=1) - adjusted_percentile_rank = percentile_rank / 100 * weight_cdf[..., -1] - - # Ignore leading `sample_weight=0` observations when `percentile_rank=0` (#20528) - mask = adjusted_percentile_rank == 0 - adjusted_percentile_rank[mask] = xp.nextafter( - adjusted_percentile_rank[mask], adjusted_percentile_rank[mask] + 1 - ) - # For each feature with index j, find sample index i of the scalar value - # `adjusted_percentile_rank[j]` in 1D array `weight_cdf[j]`, such that: - # weight_cdf[j, i-1] < adjusted_percentile_rank[j] <= weight_cdf[j, i]. - # Note `searchsorted` defaults to equality on the right, whereas Hyndman and Fan - # reference equation has equality on the left. 
- percentile_indices = xp.stack( - [ - xp.searchsorted( - weight_cdf[feature_idx, ...], adjusted_percentile_rank[feature_idx] - ) - for feature_idx in range(weight_cdf.shape[0]) - ], - ) - # `percentile_indices` may be equal to `sorted_idx.shape[0]` due to floating - # point error (see #11813) - max_idx = sorted_idx.shape[0] - 1 - percentile_indices = xp.clip(percentile_indices, 0, max_idx) - - col_indices = xp.arange(array.shape[1], device=device) - percentile_in_sorted = sorted_idx[percentile_indices, col_indices] - - if average: - # From Hyndman and Fan (1996), `fraction_above` is `g` - fraction_above = ( - weight_cdf[col_indices, percentile_indices] - adjusted_percentile_rank - ) - is_fraction_above = fraction_above > xp.finfo(floating_dtype).eps - percentile_plus_one_indices = xp.clip(percentile_indices + 1, 0, max_idx) - percentile_plus_one_in_sorted = sorted_idx[ - percentile_plus_one_indices, col_indices - ] - # Handle case when next index ('plus one') has sample weight of 0 - zero_weight_cols = col_indices[ - sample_weight[percentile_plus_one_in_sorted, col_indices] == 0 - ] - for col_idx in zero_weight_cols: - cdf_val = weight_cdf[col_idx, percentile_indices[col_idx]] - # Search for next index where `weighted_cdf` is greater - next_index = xp.searchsorted( - weight_cdf[col_idx, ...], cdf_val, side="right" - ) - # Handle case where there are trailing 0 sample weight samples - # and `percentile_indices` is already max index - if next_index >= max_idx: - # use original `percentile_indices` again - next_index = percentile_indices[col_idx] - - percentile_plus_one_in_sorted[col_idx] = sorted_idx[next_index, col_idx] - - result = xp.where( - is_fraction_above, - array[percentile_in_sorted, col_indices], - ( - array[percentile_in_sorted, col_indices] - + array[percentile_plus_one_in_sorted, col_indices] - ) - / 2, - ) - else: - result = array[percentile_in_sorted, col_indices] - - return result[0] if n_dim == 1 else result - - -def _weighted_percentile_inner(x, w, target_sums, out, average, xp): - if x.size == 1: - out[:] = x - return - i = x.size // 2 - partitioner = xp.argpartition(x, x.size // 2) - w_left = w[partitioner[:i]] - sum_left = xp.sum(w_left) - j = xp.searchsorted(target_sums, sum_left) - target_sums[j:] -= sum_left - x_left = None - if j > 0: - x_left = x[partitioner[:i]] - _weighted_percentile_inner( - x_left, w_left, target_sums[:j], out[:j], average, xp - ) - if j < target_sums.size: - idx_0 = xp.searchsorted(target_sums[j:], 0, side="right") - if idx_0 > 0: - x_left = x[partitioner[:i]] if x_left is None else x_left - out[j : j + idx_0] = ( - (x_left.max() + x[partitioner[i:]].min()) / 2 - if average - else x_left.max() - ) - j += idx_0 - if j < target_sums.size: - x_right = x[partitioner[i:]] - w_right = w[partitioner[i:]] - _weighted_percentile_inner( - x_right, w_right, target_sums[j:], out[j:], average, xp - ) - - -def _weighted_percentile( - array, sample_weight, percentile_rank=50, average=False, xp=None -): - xp, _, device = get_namespace_and_device(array) # `sample_weight` should follow `array` for dtypes # XXX: ^ why? Also: why floating dtype? 
(is this really floating, I don't think so) floating_dtype = _find_matching_floating_dtype(array, xp=xp) array = xp.asarray(array, dtype=floating_dtype, device=device) sample_weight = xp.asarray(sample_weight, dtype=floating_dtype, device=device) + percentile_rank = xp.asarray(percentile_rank, dtype=floating_dtype, device=device) n_dim = array.ndim if n_dim == 0: @@ -242,33 +134,46 @@ def _weighted_percentile( array = xp.reshape(array, (-1, 1)) n_features = array.shape[1] - n_dim_percentile = 0 - q = xp.asarray([percentile_rank / 100]) + n_dim_percentile = percentile_rank.ndim + if n_dim_percentile == 0: + percentile_rank = xp.reshape(percentile_rank, (1,)) + q = percentile_rank / 100 + + # Sort quantiles for efficient processing in __weighted_percentile_inner sorter = xp.argsort(q, stable=False) result = xp.empty((n_features, q.size), dtype=floating_dtype) sorted_q = q[sorter] result_sorted = xp.empty(q.size, dtype=floating_dtype) + # Compute weighted percentiles for each feature (column) for feature_idx in range(n_features): x = array[..., feature_idx] + # Ignore NaN values by masking them out mask_nnan = ~xp.isnan(x) x = x[mask_nnan] if x.size == 0: + # If all values are NaN, return NaN for this feature result[feature_idx, ...] = xp.nan continue + # Select weights for non-NaN values w = ( sample_weight[mask_nnan, feature_idx] if sample_weight.ndim == 2 else sample_weight[mask_nnan] ) + # Ignore zero weights mask_nz = w != 0 - if not mask_nz.all(): + has_zero = not mask_nz.all() + if has_zero: w = w[mask_nz] - x[mask_nz] weights_sum = xp.sum(w) if weights_sum == 0: + # If all weights are zero, return max value (consistent with NaN handling) result[feature_idx, ...] = xp.max(x) continue + if has_zero: + x = x[mask_nz] + # Recursively compute weighted percentiles using partitioning _weighted_percentile_inner( x, w, @@ -277,6 +182,7 @@ def _weighted_percentile( average=average, xp=xp, ) + # Store results in original quantile order result[feature_idx, sorter] = result_sorted if n_dim_percentile == 0: From c82c75fc27b08a3ac7ec86ce3b197e7d9d817fa0 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 28 Sep 2025 14:52:56 +0200 Subject: [PATCH 05/22] swap functions order for easier diff --- sklearn/utils/stats.py | 72 +++++++++++++++++++++--------------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 9165ad73e4a88..30f9bf1aee940 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -7,42 +7,6 @@ ) -def _weighted_percentile_inner(x, w, target_sums, out, average, xp): - if x.size == 1: - out[:] = x - return - i = x.size // 2 - partitioner = xp.argpartition(x, i) - w_left = w[partitioner[:i]] - sum_left = xp.sum(w_left) - j = xp.searchsorted(target_sums, sum_left) - target_sums[j:] -= sum_left - x_left = None - if j > 0: - # some quantiles are to be found on the left side of the partition - x_left = x[partitioner[:i]] - _weighted_percentile_inner( - x_left, w_left, target_sums[:j], out[:j], average, xp - ) - if j >= target_sums.size: - return - idx_0 = xp.searchsorted(target_sums[j:], 0, side="right") - if idx_0 > 0: - # some quantiles are precisely at the index of the partition - x_left = x[partitioner[:i]] if x_left is None else x_left - out[j : j + idx_0] = ( - (x_left.max() + x[partitioner[i:]].min()) / 2 if average else x_left.max() - ) - j += idx_0 - if j < target_sums.size: - # some quantiles are to be found on the right side of the partition - x_right = x[partitioner[i:]] - w_right = w[partitioner[i:]] - 
_weighted_percentile_inner( - x_right, w_right, target_sums[j:], out[j:], average, xp - ) - - def _weighted_percentile( array, sample_weight, percentile_rank=50, average=False, xp=None ): @@ -189,3 +153,39 @@ def _weighted_percentile( result = result[..., 0] return result[0] if n_dim == 1 else result + + +def _weighted_percentile_inner(x, w, target_sums, out, average, xp): + if x.size == 1: + out[:] = x + return + i = x.size // 2 + partitioner = xp.argpartition(x, i) + w_left = w[partitioner[:i]] + sum_left = xp.sum(w_left) + j = xp.searchsorted(target_sums, sum_left) + target_sums[j:] -= sum_left + x_left = None + if j > 0: + # some quantiles are to be found on the left side of the partition + x_left = x[partitioner[:i]] + _weighted_percentile_inner( + x_left, w_left, target_sums[:j], out[:j], average, xp + ) + if j >= target_sums.size: + return + idx_0 = xp.searchsorted(target_sums[j:], 0, side="right") + if idx_0 > 0: + # some quantiles are precisely at the index of the partition + x_left = x[partitioner[:i]] if x_left is None else x_left + out[j : j + idx_0] = ( + (x_left.max() + x[partitioner[i:]].min()) / 2 if average else x_left.max() + ) + j += idx_0 + if j < target_sums.size: + # some quantiles are to be found on the right side of the partition + x_right = x[partitioner[i:]] + w_right = w[partitioner[i:]] + _weighted_percentile_inner( + x_right, w_right, target_sums[j:], out[j:], average, xp + ) From f6f877c2574bf77fe6f6c083063208120a5918bc Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 28 Sep 2025 15:09:43 +0200 Subject: [PATCH 06/22] use new signature where useful --- sklearn/dummy.py | 16 ++++++---------- sklearn/preprocessing/_discretization.py | 16 ++-------------- sklearn/preprocessing/_polynomial.py | 7 +------ sklearn/utils/stats.py | 1 - 4 files changed, 9 insertions(+), 31 deletions(-) diff --git a/sklearn/dummy.py b/sklearn/dummy.py index 2eab0e53e2aa6..f0823567abd9e 100644 --- a/sklearn/dummy.py +++ b/sklearn/dummy.py @@ -581,10 +581,9 @@ def fit(self, X, y, sample_weight=None): if sample_weight is None: self.constant_ = np.median(y, axis=0) else: - self.constant_ = [ - _weighted_percentile(y[:, k], sample_weight, percentile_rank=50.0) - for k in range(self.n_outputs_) - ] + self.constant_ = _weighted_percentile( + y, sample_weight, percentile_rank=50.0 + ) elif self.strategy == "quantile": if self.quantile is None: @@ -596,12 +595,9 @@ def fit(self, X, y, sample_weight=None): if sample_weight is None: self.constant_ = np.percentile(y, axis=0, q=percentile_rank) else: - self.constant_ = [ - _weighted_percentile( - y[:, k], sample_weight, percentile_rank=percentile_rank - ) - for k in range(self.n_outputs_) - ] + self.constant_ = _weighted_percentile( + y, sample_weight, percentile_rank=percentile_rank + ) elif self.strategy == "constant": if self.constant is None: diff --git a/sklearn/preprocessing/_discretization.py b/sklearn/preprocessing/_discretization.py index 5ab6fdd4b6576..847c388599821 100644 --- a/sklearn/preprocessing/_discretization.py +++ b/sklearn/preprocessing/_discretization.py @@ -365,23 +365,11 @@ def fit(self, X, y=None, sample_weight=None): dtype=np.float64, ) else: - # TODO: make _weighted_percentile accept an array of - # quantiles instead of calling it multiple times and - # sorting the column multiple times as a result. 
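# Illustrative sketch (not part of this diff) of the vectorized signature used just
# below, assuming plain NumPy inputs:
#
#     import numpy as np
#     from sklearn.utils.stats import _weighted_percentile
#     column = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
#     sw = np.ones_like(column)
#     edges = _weighted_percentile(column, sw, np.array([0, 50, 100]))
#     # edges -> array([1., 3., 5.]): one bin edge per requested rank, from a
#     # single call instead of one call (and one sort) per percentile level.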
average = ( True if quantile_method == "averaged_inverted_cdf" else False ) - bin_edges[jj] = np.asarray( - [ - _weighted_percentile( - column, - sample_weight, - percentile_rank=p, - average=average, - ) - for p in percentile_levels - ], - dtype=np.float64, + bin_edges[jj] = _weighted_percentile( + column, sample_weight, percentile_levels, average=average ) elif self.strategy == "kmeans": from sklearn.cluster import KMeans # fixes import loops diff --git a/sklearn/preprocessing/_polynomial.py b/sklearn/preprocessing/_polynomial.py index acc2aa1138b68..e34b25fbbdd88 100644 --- a/sklearn/preprocessing/_polynomial.py +++ b/sklearn/preprocessing/_polynomial.py @@ -791,12 +791,7 @@ def _get_base_knot_positions(X, n_knots=10, knots="uniform", sample_weight=None) if sample_weight is None: knots = np.nanpercentile(X, percentile_ranks, axis=0) else: - knots = np.array( - [ - _weighted_percentile(X, sample_weight, percentile_rank) - for percentile_rank in percentile_ranks - ] - ) + knots = _weighted_percentile(X, sample_weight, percentile_ranks).T else: # knots == 'uniform': diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 30f9bf1aee940..d3c5aa9390196 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -83,7 +83,6 @@ def _weighted_percentile( Weighted percentile at the requested probability level. """ xp, _, device = get_namespace_and_device(array) - # `sample_weight` should follow `array` for dtypes # XXX: ^ why? Also: why floating dtype? (is this really floating, I don't think so) floating_dtype = _find_matching_floating_dtype(array, xp=xp) From cad861414c67ca8b822f425cf05ec1dd53950108 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 28 Sep 2025 15:18:23 +0200 Subject: [PATCH 07/22] update docstring for new signature --- sklearn/utils/stats.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index d3c5aa9390196..e5b05846194f8 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -62,9 +62,9 @@ def _weighted_percentile( Weights for each value in `array`. Must be same shape as `array` or of shape `(array.shape[0],)`. - percentile_rank: int or float, default=50 - The probability level of the percentile to compute, in percent. Must be between - 0 and 100. + percentile_rank: scalar or 1D array, default=50 + The probability level(s) of the percentile(s) to compute, in percent. Must be + between 0 and 100. If a 1D array, computes multiple percentiles. average : bool, default=False If `True`, uses the "averaged_inverted_cdf" quantile method, otherwise @@ -79,8 +79,15 @@ def _weighted_percentile( Returns ------- - percentile : scalar or 0D array if `array` 1D (or 0D), array if `array` 2D - Weighted percentile at the requested probability level. + percentile : scalar, 1D array, or 2D array + Weighted percentile at the requested probability level(s). + If `array` is 1D and `percentile_rank` is scalar, returns a scalar. 
+ If `array` is 2D and `percentile_rank` is scalar, returns a 1D array + of shape `(array.shape[1],)` + If `array` is 1D and `percentile_rank` is 1D, returns a 1D array + of shape `(percentile_rank,)` + If `array` is 2D and `percentile_rank` is 1D, returns a 2D array + of shape `(array.shape[1], percentile_rank.size)` """ xp, _, device = get_namespace_and_device(array) # `sample_weight` should follow `array` for dtypes From 7f5d47f4f051628f23d9da198f857f6ba47d9b9b Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 28 Sep 2025 15:49:32 +0200 Subject: [PATCH 08/22] adapt fully to array-API --- sklearn/utils/stats.py | 42 ++++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index e5b05846194f8..0ece20023f500 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -110,9 +110,9 @@ def _weighted_percentile( q = percentile_rank / 100 # Sort quantiles for efficient processing in __weighted_percentile_inner - sorter = xp.argsort(q, stable=False) + q_sorter = xp.argsort(q, stable=False) result = xp.empty((n_features, q.size), dtype=floating_dtype) - sorted_q = q[sorter] + sorted_q = q[q_sorter] result_sorted = xp.empty(q.size, dtype=floating_dtype) # Compute weighted percentiles for each feature (column) @@ -133,7 +133,7 @@ def _weighted_percentile( ) # Ignore zero weights mask_nz = w != 0 - has_zero = not mask_nz.all() + has_zero = not xp.all(mask_nz) if has_zero: w = w[mask_nz] weights_sum = xp.sum(w) @@ -144,16 +144,23 @@ def _weighted_percentile( if has_zero: x = x[mask_nz] # Recursively compute weighted percentiles using partitioning + w_sorted = False + if not hasattr(xp, "argpartition"): + x_sorter = xp.argsort(x, stable=False) + w = w[x_sorter] + x = x[x_sorter] + w_sorted = True _weighted_percentile_inner( x, w, target_sums=weights_sum * sorted_q, out=result_sorted, average=average, + w_sorted=w_sorted, xp=xp, ) # Store results in original quantile order - result[feature_idx, sorter] = result_sorted + result[feature_idx, q_sorter] = result_sorted if n_dim_percentile == 0: result = result[..., 0] @@ -161,22 +168,28 @@ def _weighted_percentile( return result[0] if n_dim == 1 else result -def _weighted_percentile_inner(x, w, target_sums, out, average, xp): +def _weighted_percentile_inner(x, w, target_sums, out, average, w_sorted, xp): if x.size == 1: out[:] = x return i = x.size // 2 - partitioner = xp.argpartition(x, i) - w_left = w[partitioner[:i]] + if w_sorted: + w_left = w[:i] + x_left = x[:i] + w_right = w[i:] + x_right = x[i:] + else: + partitioner = xp.argpartition(x, i) + w_left = w[partitioner[:i]] + x_left = w_right = x_right = None sum_left = xp.sum(w_left) j = xp.searchsorted(target_sums, sum_left) target_sums[j:] -= sum_left - x_left = None if j > 0: # some quantiles are to be found on the left side of the partition - x_left = x[partitioner[:i]] + x_left = x[partitioner[:i]] if x_left is None else x_left _weighted_percentile_inner( - x_left, w_left, target_sums[:j], out[:j], average, xp + x_left, w_left, target_sums[:j], out[:j], average, w_sorted, xp ) if j >= target_sums.size: return @@ -184,14 +197,15 @@ def _weighted_percentile_inner(x, w, target_sums, out, average, xp): if idx_0 > 0: # some quantiles are precisely at the index of the partition x_left = x[partitioner[:i]] if x_left is None else x_left + x_right = x[partitioner[i:]] if x_right is None else x_right out[j : j + idx_0] = ( - (x_left.max() + x[partitioner[i:]].min()) / 2 if average else x_left.max() + 
(xp.max(x_left) + xp.min(x_right)) / 2 if average else xp.max(x_left) ) j += idx_0 if j < target_sums.size: # some quantiles are to be found on the right side of the partition - x_right = x[partitioner[i:]] - w_right = w[partitioner[i:]] + x_right = x[partitioner[i:]] if x_right is None else x_right + w_right = w[partitioner[i:]] if w_right is None else w_right _weighted_percentile_inner( - x_right, w_right, target_sums[j:], out[j:], average, xp + x_right, w_right, target_sums[j:], out[j:], average, w_sorted, xp ) From 649b2711390916381c04bd39311094306802734f Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 28 Sep 2025 16:21:52 +0200 Subject: [PATCH 09/22] fix array API compat --- sklearn/utils/stats.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 0ece20023f500..0d838bd9c5ef8 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -108,12 +108,13 @@ def _weighted_percentile( if n_dim_percentile == 0: percentile_rank = xp.reshape(percentile_rank, (1,)) q = percentile_rank / 100 + n_percentiles = percentile_rank.shape[0] # Sort quantiles for efficient processing in __weighted_percentile_inner q_sorter = xp.argsort(q, stable=False) - result = xp.empty((n_features, q.size), dtype=floating_dtype) + result = xp.empty((n_features, n_percentiles), dtype=floating_dtype) sorted_q = q[q_sorter] - result_sorted = xp.empty(q.size, dtype=floating_dtype) + result_sorted = xp.empty((n_percentiles,), dtype=floating_dtype) # Compute weighted percentiles for each feature (column) for feature_idx in range(n_features): @@ -121,7 +122,7 @@ def _weighted_percentile( # Ignore NaN values by masking them out mask_nnan = ~xp.isnan(x) x = x[mask_nnan] - if x.size == 0: + if len(x) == 0: # If all values are NaN, return NaN for this feature result[feature_idx, ...] 
= xp.nan continue @@ -169,10 +170,11 @@ def _weighted_percentile( def _weighted_percentile_inner(x, w, target_sums, out, average, w_sorted, xp): - if x.size == 1: + n = x.shape[0] + if n == 1: out[:] = x return - i = x.size // 2 + i = n // 2 if w_sorted: w_left = w[:i] x_left = x[:i] @@ -191,7 +193,7 @@ def _weighted_percentile_inner(x, w, target_sums, out, average, w_sorted, xp): _weighted_percentile_inner( x_left, w_left, target_sums[:j], out[:j], average, w_sorted, xp ) - if j >= target_sums.size: + if j >= target_sums.shape[0]: return idx_0 = xp.searchsorted(target_sums[j:], 0, side="right") if idx_0 > 0: @@ -202,7 +204,7 @@ def _weighted_percentile_inner(x, w, target_sums, out, average, w_sorted, xp): (xp.max(x_left) + xp.min(x_right)) / 2 if average else xp.max(x_left) ) j += idx_0 - if j < target_sums.size: + if j < target_sums.shape[0]: # some quantiles are to be found on the right side of the partition x_right = x[partitioner[i:]] if x_right is None else x_right w_right = w[partitioner[i:]] if w_right is None else w_right From cda231c4e97e187010048ce7e2ac7a315515f2e1 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 28 Sep 2025 18:43:21 +0200 Subject: [PATCH 10/22] another array API fix: TypeError: object of type 'Array' has no len() --- sklearn/utils/stats.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 0d838bd9c5ef8..ef8bae3a18f90 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -85,9 +85,9 @@ def _weighted_percentile( If `array` is 2D and `percentile_rank` is scalar, returns a 1D array of shape `(array.shape[1],)` If `array` is 1D and `percentile_rank` is 1D, returns a 1D array - of shape `(percentile_rank,)` + of shape `(percentile_rank.shape[0],)` If `array` is 2D and `percentile_rank` is 1D, returns a 2D array - of shape `(array.shape[1], percentile_rank.size)` + of shape `(array.shape[1], percentile_rank.shape[0])` """ xp, _, device = get_namespace_and_device(array) # `sample_weight` should follow `array` for dtypes @@ -122,7 +122,7 @@ def _weighted_percentile( # Ignore NaN values by masking them out mask_nnan = ~xp.isnan(x) x = x[mask_nnan] - if len(x) == 0: + if x.shape[0] == 0: # If all values are NaN, return NaN for this feature result[feature_idx, ...] 
= xp.nan continue From 9c4a5adbc87fe58df8c585e468f4a7448994dd93 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 28 Sep 2025 19:37:51 +0200 Subject: [PATCH 11/22] more array-API fixes; tested locally; but I cant test everything I dont have a nvidia-GPU --- sklearn/utils/stats.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index ef8bae3a18f90..d2eaf00ff4fb2 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -107,30 +107,28 @@ def _weighted_percentile( n_dim_percentile = percentile_rank.ndim if n_dim_percentile == 0: percentile_rank = xp.reshape(percentile_rank, (1,)) + if xp.any(percentile_rank[1:] < percentile_rank[:-1]): + raise ValueError("percentile_rank must be non-decreasing") + q = percentile_rank / 100 n_percentiles = percentile_rank.shape[0] - - # Sort quantiles for efficient processing in __weighted_percentile_inner - q_sorter = xp.argsort(q, stable=False) result = xp.empty((n_features, n_percentiles), dtype=floating_dtype) - sorted_q = q[q_sorter] - result_sorted = xp.empty((n_percentiles,), dtype=floating_dtype) # Compute weighted percentiles for each feature (column) for feature_idx in range(n_features): x = array[..., feature_idx] # Ignore NaN values by masking them out - mask_nnan = ~xp.isnan(x) - x = x[mask_nnan] + mask_not_nan = ~xp.isnan(x) + x = x[mask_not_nan] if x.shape[0] == 0: # If all values are NaN, return NaN for this feature result[feature_idx, ...] = xp.nan continue # Select weights for non-NaN values w = ( - sample_weight[mask_nnan, feature_idx] + sample_weight[..., feature_idx][mask_not_nan] if sample_weight.ndim == 2 - else sample_weight[mask_nnan] + else sample_weight[mask_not_nan] ) # Ignore zero weights mask_nz = w != 0 @@ -154,14 +152,12 @@ def _weighted_percentile( _weighted_percentile_inner( x, w, - target_sums=weights_sum * sorted_q, - out=result_sorted, + target_sums=weights_sum * q, + out=result[feature_idx, ...], average=average, w_sorted=w_sorted, xp=xp, ) - # Store results in original quantile order - result[feature_idx, q_sorter] = result_sorted if n_dim_percentile == 0: result = result[..., 0] @@ -195,7 +191,9 @@ def _weighted_percentile_inner(x, w, target_sums, out, average, w_sorted, xp): ) if j >= target_sums.shape[0]: return - idx_0 = xp.searchsorted(target_sums[j:], 0, side="right") + zero = xp.sum(target_sums[:0]) + # ^ array API needs an array as argument for searchsorted + idx_0 = xp.searchsorted(target_sums[j:], zero, side="right") if idx_0 > 0: # some quantiles are precisely at the index of the partition x_left = x[partitioner[:i]] if x_left is None else x_left From 4ea221e568d4cf8ed31d72701ebecc84bffc2f00 Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 29 Sep 2025 12:34:06 +0200 Subject: [PATCH 12/22] tmp: old for benchmark --- sklearn/utils/stats.py | 183 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 183 insertions(+) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index d2eaf00ff4fb2..f43c80182dfc3 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -209,3 +209,186 @@ def _weighted_percentile_inner(x, w, target_sums, out, average, w_sorted, xp): _weighted_percentile_inner( x_right, w_right, target_sums[j:], out[j:], average, w_sorted, xp ) + + +def _old_weighted_percentile( + array, sample_weight, percentile_rank=50, average=False, xp=None +): + """Compute the weighted percentile. 
+ + Implement an array API compatible (weighted version) of NumPy's 'inverted_cdf' + method when `average=False` (default) and 'averaged_inverted_cdf' when + `average=True`. + + For an array ordered by increasing values, when the percentile lies exactly on a + data point: + + * 'inverted_cdf' takes the exact data point. + * 'averaged_inverted_cdf' takes the average of the exact data point and the one + above it (this means it gives the same result as `median` for unit weights). + + E.g., for the array [1, 2, 3, 4] the percentile rank at each data point would + be [25, 50, 75, 100]. Percentile rank 50 lies on '2'. 'average_inverted_cdf' + computes the average of '2' and '3', making it 'symmetrical' because if you + reverse the array, rank 50 would fall on '3'. It also matches 'median'. + On the other hand, 'inverted_cdf', which does not satisfy the symmetry property, + would give '2'. + + When the requested percentile lies between two data points, both methods return + the higher data point. + E.g., for the array [1, 2, 3, 4, 5] the percentile rank at each data point would + be [20, 40, 60, 80, 100]. Percentile rank 50, lies between '2' and '3'. Taking the + higher data point is symmetrical because if you reverse the array, 50 would lie + between '4' and '3'. Both methods match median in this case. + + If `array` is a 2D array, the `values` are selected along axis 0. + + `NaN` values are ignored by setting their weights to 0. If `array` is 2D, this + is done in a column-isolated manner: a `NaN` in the second column, does not impact + the percentile computed for the first column even if `sample_weight` is 1D. + + .. versionchanged:: 0.24 + Accepts 2D `array`. + + .. versionchanged:: 1.7 + Supports handling of `NaN` values. + + .. versionchanged:: 1.8 + Supports `average`, which calculates percentile using the + "averaged_inverted_cdf" method. + + Parameters + ---------- + array : 1D or 2D array + Values to take the weighted percentile of. + + sample_weight: 1D or 2D array + Weights for each value in `array`. Must be same shape as `array` or of shape + `(array.shape[0],)`. + + percentile_rank: int or float, default=50 + The probability level of the percentile to compute, in percent. Must be between + 0 and 100. + + average : bool, default=False + If `True`, uses the "averaged_inverted_cdf" quantile method, otherwise + defaults to "inverted_cdf". "averaged_inverted_cdf" is symmetrical with + unit `sample_weight`, such that the total of `sample_weight` below or equal to + `_weighted_percentile(percentile_rank)` is the same as the total of + `sample_weight` above or equal to `_weighted_percentile(100-percentile_rank)`. + This symmetry is not guaranteed with non-unit weights. + + xp : array_namespace, default=None + The standard-compatible namespace for `array`. Default: infer. + + Returns + ------- + percentile : scalar or 0D array if `array` 1D (or 0D), array if `array` 2D + Weighted percentile at the requested probability level. 
+ """ + xp, _, device = get_namespace_and_device(array) + # `sample_weight` should follow `array` for dtypes + floating_dtype = _find_matching_floating_dtype(array, xp=xp) + array = xp.asarray(array, dtype=floating_dtype, device=device) + sample_weight = xp.asarray(sample_weight, dtype=floating_dtype, device=device) + + n_dim = array.ndim + if n_dim == 0: + return array + if array.ndim == 1: + array = xp.reshape(array, (-1, 1)) + # When sample_weight 1D, repeat for each array.shape[1] + if array.shape != sample_weight.shape and array.shape[0] == sample_weight.shape[0]: + sample_weight = xp.tile(sample_weight, (array.shape[1], 1)).T + # Sort `array` and `sample_weight` along axis=0: + sorted_idx = xp.argsort(array, axis=0, stable=False) + sorted_weights = xp.take_along_axis(sample_weight, sorted_idx, axis=0) + + # Set NaN values in `sample_weight` to 0. Only perform this operation if NaN + # values present to avoid temporary allocations of size `(n_samples, n_features)`. + n_features = array.shape[1] + largest_value_per_column = array[ + sorted_idx[-1, ...], xp.arange(n_features, device=device) + ] + # NaN values get sorted to end (largest value) + if xp.any(xp.isnan(largest_value_per_column)): + sorted_nan_mask = xp.take_along_axis(xp.isnan(array), sorted_idx, axis=0) + sorted_weights[sorted_nan_mask] = 0 + + # Compute the weighted cumulative distribution function (CDF) based on + # `sample_weight` and scale `percentile_rank` along it. + # + # Note: we call `xp.cumulative_sum` on the transposed `sorted_weights` to + # ensure that the result is of shape `(n_features, n_samples)` so + # `xp.searchsorted` calls take contiguous inputs as a result (for + # performance reasons). + weight_cdf = xp.cumulative_sum(sorted_weights.T, axis=1) + adjusted_percentile_rank = percentile_rank / 100 * weight_cdf[..., -1] + + # Ignore leading `sample_weight=0` observations when `percentile_rank=0` (#20528) + mask = adjusted_percentile_rank == 0 + adjusted_percentile_rank[mask] = xp.nextafter( + adjusted_percentile_rank[mask], adjusted_percentile_rank[mask] + 1 + ) + # For each feature with index j, find sample index i of the scalar value + # `adjusted_percentile_rank[j]` in 1D array `weight_cdf[j]`, such that: + # weight_cdf[j, i-1] < adjusted_percentile_rank[j] <= weight_cdf[j, i]. + # Note `searchsorted` defaults to equality on the right, whereas Hyndman and Fan + # reference equation has equality on the left. 
+ percentile_indices = xp.stack( + [ + xp.searchsorted( + weight_cdf[feature_idx, ...], adjusted_percentile_rank[feature_idx] + ) + for feature_idx in range(weight_cdf.shape[0]) + ], + ) + # `percentile_indices` may be equal to `sorted_idx.shape[0]` due to floating + # point error (see #11813) + max_idx = sorted_idx.shape[0] - 1 + percentile_indices = xp.clip(percentile_indices, 0, max_idx) + + col_indices = xp.arange(array.shape[1], device=device) + percentile_in_sorted = sorted_idx[percentile_indices, col_indices] + + if average: + # From Hyndman and Fan (1996), `fraction_above` is `g` + fraction_above = ( + weight_cdf[col_indices, percentile_indices] - adjusted_percentile_rank + ) + is_fraction_above = fraction_above > xp.finfo(floating_dtype).eps + percentile_plus_one_indices = xp.clip(percentile_indices + 1, 0, max_idx) + percentile_plus_one_in_sorted = sorted_idx[ + percentile_plus_one_indices, col_indices + ] + # Handle case when next index ('plus one') has sample weight of 0 + zero_weight_cols = col_indices[ + sample_weight[percentile_plus_one_in_sorted, col_indices] == 0 + ] + for col_idx in zero_weight_cols: + cdf_val = weight_cdf[col_idx, percentile_indices[col_idx]] + # Search for next index where `weighted_cdf` is greater + next_index = xp.searchsorted( + weight_cdf[col_idx, ...], cdf_val, side="right" + ) + # Handle case where there are trailing 0 sample weight samples + # and `percentile_indices` is already max index + if next_index >= max_idx: + # use original `percentile_indices` again + next_index = percentile_indices[col_idx] + + percentile_plus_one_in_sorted[col_idx] = sorted_idx[next_index, col_idx] + + result = xp.where( + is_fraction_above, + array[percentile_in_sorted, col_indices], + ( + array[percentile_in_sorted, col_indices] + + array[percentile_plus_one_in_sorted, col_indices] + ) + / 2, + ) + else: + result = array[percentile_in_sorted, col_indices] + + return result[0] if n_dim == 1 else result From f96d3343e49414843aa18c3d3fff77049cff5383 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 9 Oct 2025 12:42:18 +0200 Subject: [PATCH 13/22] remove comment about floating dtype --- sklearn/utils/stats.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 1c60ffc13508c..dcf513dc3d93c 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -91,7 +91,6 @@ def _weighted_percentile( """ xp, _, device = get_namespace_and_device(array) # `sample_weight` should follow `array` for dtypes - # XXX: ^ why? Also: why floating dtype? 
(is this really floating, I don't think so) floating_dtype = _find_matching_floating_dtype(array, xp=xp) array = xp.asarray(array, dtype=floating_dtype, device=device) sample_weight = xp.asarray(sample_weight, dtype=floating_dtype, device=device) From 013725da09113a55cda1ce49836b97f294c68b21 Mon Sep 17 00:00:00 2001 From: Arthur Lacote Date: Thu, 9 Oct 2025 12:44:59 +0200 Subject: [PATCH 14/22] Fix device error Co-authored-by: Olivier Grisel --- sklearn/utils/stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index dcf513dc3d93c..f875f54958ff0 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -111,7 +111,7 @@ def _weighted_percentile( q = percentile_rank / 100 n_percentiles = percentile_rank.shape[0] - result = xp.empty((n_features, n_percentiles), dtype=floating_dtype) + result = xp.empty((n_features, n_percentiles), dtype=floating_dtype, device=device) # Compute weighted percentiles for each feature (column) for feature_idx in range(n_features): From 7b9e50b689fb36cb74788f5145687406a0e02755 Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 13 Oct 2025 17:00:37 +0200 Subject: [PATCH 15/22] mitigate perf loss with d>>1 --- sklearn/utils/stats.py | 66 +++++++++++++++++++++++++----------------- 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index dcf513dc3d93c..bfda3cbe37e86 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -109,6 +109,11 @@ def _weighted_percentile( if xp.any(percentile_rank[1:] < percentile_rank[:-1]): raise ValueError("percentile_rank must be non-decreasing") + sorted_idx = None + w_sorted = not hasattr(xp, "argpartition") + if w_sorted: + sorted_idx = xp.argsort(array, axis=0, stable=False) + q = percentile_rank / 100 n_percentiles = percentile_rank.shape[0] result = xp.empty((n_features, n_percentiles), dtype=floating_dtype) @@ -116,6 +121,16 @@ def _weighted_percentile( # Compute weighted percentiles for each feature (column) for feature_idx in range(n_features): x = array[..., feature_idx] + w = ( + sample_weight[..., feature_idx] + if sample_weight.ndim == 2 + else sample_weight + ) + if w_sorted: + x_sorter = sorted_idx[..., feature_idx] + x = x[x_sorter] + w = w[x_sorter] + # Ignore NaN values by masking them out mask_not_nan = ~xp.isnan(x) x = x[mask_not_nan] @@ -124,11 +139,7 @@ def _weighted_percentile( result[feature_idx, ...] 
= xp.nan continue # Select weights for non-NaN values - w = ( - sample_weight[..., feature_idx][mask_not_nan] - if sample_weight.ndim == 2 - else sample_weight[mask_not_nan] - ) + w = w[mask_not_nan] # Ignore zero weights mask_nz = w != 0 has_zero = not xp.all(mask_nz) @@ -142,12 +153,6 @@ def _weighted_percentile( if has_zero: x = x[mask_nz] # Recursively compute weighted percentiles using partitioning - w_sorted = False - if not hasattr(xp, "argpartition"): - x_sorter = xp.argsort(x, stable=False) - w = w[x_sorter] - x = x[x_sorter] - w_sorted = True _weighted_percentile_inner( x, w, @@ -166,25 +171,24 @@ def _weighted_percentile( def _weighted_percentile_inner(x, w, target_sums, out, average, w_sorted, xp): n = x.shape[0] - if n == 1: - out[:] = x + if n < 100 and not w_sorted: + sorted_idx = xp.argsort(x, stable=False) + w = w[sorted_idx] + x = x[sorted_idx] + w_sorted = True + if w_sorted: + _weighted_percentile_inner_sorted(x, w, target_sums, out, average, xp) return i = n // 2 - if w_sorted: - w_left = w[:i] - x_left = x[:i] - w_right = w[i:] - x_right = x[i:] - else: - partitioner = xp.argpartition(x, i) - w_left = w[partitioner[:i]] - x_left = w_right = x_right = None + partitioner = xp.argpartition(x, i) + w_left = w[partitioner[:i]] + x_right = None sum_left = xp.sum(w_left) j = xp.searchsorted(target_sums, sum_left) target_sums[j:] -= sum_left if j > 0: # some quantiles are to be found on the left side of the partition - x_left = x[partitioner[:i]] if x_left is None else x_left + x_left = x[partitioner[:i]] _weighted_percentile_inner( x_left, w_left, target_sums[:j], out[:j], average, w_sorted, xp ) @@ -195,8 +199,8 @@ def _weighted_percentile_inner(x, w, target_sums, out, average, w_sorted, xp): idx_0 = xp.searchsorted(target_sums[j:], zero, side="right") if idx_0 > 0: # some quantiles are precisely at the index of the partition - x_left = x[partitioner[:i]] if x_left is None else x_left - x_right = x[partitioner[i:]] if x_right is None else x_right + x_left = x[partitioner[:i]] + x_right = x[partitioner[i:]] out[j : j + idx_0] = ( (xp.max(x_left) + xp.min(x_right)) / 2 if average else xp.max(x_left) ) @@ -204,12 +208,22 @@ def _weighted_percentile_inner(x, w, target_sums, out, average, w_sorted, xp): if j < target_sums.shape[0]: # some quantiles are to be found on the right side of the partition x_right = x[partitioner[i:]] if x_right is None else x_right - w_right = w[partitioner[i:]] if w_right is None else w_right + w_right = w[partitioner[i:]] _weighted_percentile_inner( x_right, w_right, target_sums[j:], out[j:], average, w_sorted, xp ) +def _weighted_percentile_inner_sorted(x, w, target_sums, out, average, xp): + cw = xp.cumsum(w, axis=0) + idx = xp.searchsorted(cw, target_sums) + out[:] = x[idx] + if average: + mask_0 = cw[idx] == target_sums + if mask_0.any(): + out[mask_0] = (x[idx] + x[idx + 1]) / 2 + + def _old_weighted_percentile( array, sample_weight, percentile_rank=50, average=False, xp=None ): From 35f6d8d0dd72c7a644812a4d9d76fe56d7245dd3 Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 13 Oct 2025 17:03:53 +0200 Subject: [PATCH 16/22] minor fix for average --- sklearn/utils/stats.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 640355b91bf2d..b337172e8ed58 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -221,7 +221,10 @@ def _weighted_percentile_inner_sorted(x, w, target_sums, out, average, xp): if average: mask_0 = cw[idx] == target_sums if mask_0.any(): - 
out[mask_0] = (x[idx] + x[idx + 1]) / 2 + idx = idx[mask_0] + idx_p1 = idx + 1 + idx_p1 = xp.minimum(idx_p1, x.shape[0] - 1) + out[mask_0] = (x[idx] + x[idx_p1]) / 2 def _old_weighted_percentile( From fd8b2c95541a5df4ac099976247786d1590e8b9e Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 17 Oct 2025 18:16:10 +0200 Subject: [PATCH 17/22] WIP: inner func handles 2D but only 1 quantile --- sklearn/utils/stats.py | 92 +++++++++++++++++------------------------- 1 file changed, 38 insertions(+), 54 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index b337172e8ed58..89cfa3f8aca06 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -169,62 +169,46 @@ def _weighted_percentile( return result[0] if n_dim == 1 else result -def _weighted_percentile_inner(x, w, target_sums, out, average, w_sorted, xp): - n = x.shape[0] - if n < 100 and not w_sorted: - sorted_idx = xp.argsort(x, stable=False) - w = w[sorted_idx] - x = x[sorted_idx] - w_sorted = True - if w_sorted: - _weighted_percentile_inner_sorted(x, w, target_sums, out, average, xp) - return - i = n // 2 - partitioner = xp.argpartition(x, i) - w_left = w[partitioner[:i]] - x_right = None - sum_left = xp.sum(w_left) - j = xp.searchsorted(target_sums, sum_left) - target_sums[j:] -= sum_left - if j > 0: - # some quantiles are to be found on the left side of the partition - x_left = x[partitioner[:i]] - _weighted_percentile_inner( - x_left, w_left, target_sums[:j], out[:j], average, w_sorted, xp - ) - if j >= target_sums.shape[0]: - return - zero = xp.sum(target_sums[:0]) - # ^ array API needs an array as argument for searchsorted - idx_0 = xp.searchsorted(target_sums[j:], zero, side="right") - if idx_0 > 0: - # some quantiles are precisely at the index of the partition - x_left = x[partitioner[:i]] - x_right = x[partitioner[i:]] - out[j : j + idx_0] = ( - (xp.max(x_left) + xp.min(x_right)) / 2 if average else xp.max(x_left) - ) - j += idx_0 - if j < target_sums.shape[0]: - # some quantiles are to be found on the right side of the partition - x_right = x[partitioner[i:]] if x_right is None else x_right - w_right = w[partitioner[i:]] - _weighted_percentile_inner( - x_right, w_right, target_sums[j:], out[j:], average, w_sorted, xp - ) +def _weighted_percentile_inner(x, w, target_sums, average, w_sorted, xp): + """ + x: (d, n) + w: (d, n) + target_sums: (d,) + """ + d, n = x.shape + while n > 1: + i = (n + 1) // 2 + partitioner = xp.argpartition(x, i, axis=1) + w_left = xp.take_along_axis(w, partitioner[:, :i], axis=1) + sum_left = xp.sum(w_left, axis=1) -def _weighted_percentile_inner_sorted(x, w, target_sums, out, average, xp): - cw = xp.cumsum(w, axis=0) - idx = xp.searchsorted(cw, target_sums) - out[:] = x[idx] - if average: - mask_0 = cw[idx] == target_sums - if mask_0.any(): - idx = idx[mask_0] - idx_p1 = idx + 1 - idx_p1 = xp.minimum(idx_p1, x.shape[0] - 1) - out[mask_0] = (x[idx] + x[idx_p1]) / 2 + mask_exact = target_sums == sum_left + + if mask_exact.any(): + pass + + mask_go_left = target_sums < sum_left + mask_go_right = target_sums > sum_left + + x_next = xp.full_like(x, fill_value=xp.inf, shape=(d, i)) + w_next = xp.zeros_like(w, shape=(d, i)) + + target_sums[mask_go_left] -= sum_left + left_part = partitioner[mask_go_left][:, :i] + i_right = n - i + right_part = partitioner[mask_go_right][:, i_right:] + + x_next[mask_go_left] = xp.take_along_axis(x[mask_go_left], left_part, axis=1) + x_next[mask_go_right] = xp.take_along_axis(x[mask_go_right], right_part, axis=1) + w_next[mask_go_left] = 
w_left[mask_go_left] + w_next[mask_go_right] = xp.take_along_axis(w[mask_go_right], right_part, axis=1) + + x = x_next + w = w_next + n = i + + return x[:, 0] def _old_weighted_percentile( From d59abf53403e2b4fdf6907a4406c44215ce1e33f Mon Sep 17 00:00:00 2001 From: Arthur Date: Fri, 17 Oct 2025 18:48:45 +0200 Subject: [PATCH 18/22] restore back prev implem. and loop to compute multiple percentiles with only one sort/cumsum --- sklearn/utils/stats.py | 334 ++++++++++------------------------------- 1 file changed, 76 insertions(+), 258 deletions(-) diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index 89cfa3f8aca06..a2827fc503b8d 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -101,205 +101,14 @@ def _weighted_percentile( return array if array.ndim == 1: array = xp.reshape(array, (-1, 1)) + # When sample_weight 1D, repeat for each array.shape[1] + if array.shape != sample_weight.shape and array.shape[0] == sample_weight.shape[0]: + sample_weight = xp.tile(sample_weight, (array.shape[1], 1)).T - n_features = array.shape[1] n_dim_percentile = percentile_rank.ndim if n_dim_percentile == 0: percentile_rank = xp.reshape(percentile_rank, (1,)) - if xp.any(percentile_rank[1:] < percentile_rank[:-1]): - raise ValueError("percentile_rank must be non-decreasing") - - sorted_idx = None - w_sorted = not hasattr(xp, "argpartition") - if w_sorted: - sorted_idx = xp.argsort(array, axis=0, stable=False) - - q = percentile_rank / 100 - n_percentiles = percentile_rank.shape[0] - result = xp.empty((n_features, n_percentiles), dtype=floating_dtype, device=device) - - # Compute weighted percentiles for each feature (column) - for feature_idx in range(n_features): - x = array[..., feature_idx] - w = ( - sample_weight[..., feature_idx] - if sample_weight.ndim == 2 - else sample_weight - ) - if w_sorted: - x_sorter = sorted_idx[..., feature_idx] - x = x[x_sorter] - w = w[x_sorter] - - # Ignore NaN values by masking them out - mask_not_nan = ~xp.isnan(x) - x = x[mask_not_nan] - if x.shape[0] == 0: - # If all values are NaN, return NaN for this feature - result[feature_idx, ...] = xp.nan - continue - # Select weights for non-NaN values - w = w[mask_not_nan] - # Ignore zero weights - mask_nz = w != 0 - has_zero = not xp.all(mask_nz) - if has_zero: - w = w[mask_nz] - weights_sum = xp.sum(w) - if weights_sum == 0: - # If all weights are zero, return max value (consistent with NaN handling) - result[feature_idx, ...] 
= xp.max(x) - continue - if has_zero: - x = x[mask_nz] - # Recursively compute weighted percentiles using partitioning - _weighted_percentile_inner( - x, - w, - target_sums=weights_sum * q, - out=result[feature_idx, ...], - average=average, - w_sorted=w_sorted, - xp=xp, - ) - - if n_dim_percentile == 0: - result = result[..., 0] - - return result[0] if n_dim == 1 else result - - -def _weighted_percentile_inner(x, w, target_sums, average, w_sorted, xp): - """ - x: (d, n) - w: (d, n) - target_sums: (d,) - """ - d, n = x.shape - - while n > 1: - i = (n + 1) // 2 - partitioner = xp.argpartition(x, i, axis=1) - w_left = xp.take_along_axis(w, partitioner[:, :i], axis=1) - sum_left = xp.sum(w_left, axis=1) - - mask_exact = target_sums == sum_left - - if mask_exact.any(): - pass - - mask_go_left = target_sums < sum_left - mask_go_right = target_sums > sum_left - - x_next = xp.full_like(x, fill_value=xp.inf, shape=(d, i)) - w_next = xp.zeros_like(w, shape=(d, i)) - - target_sums[mask_go_left] -= sum_left - left_part = partitioner[mask_go_left][:, :i] - i_right = n - i - right_part = partitioner[mask_go_right][:, i_right:] - - x_next[mask_go_left] = xp.take_along_axis(x[mask_go_left], left_part, axis=1) - x_next[mask_go_right] = xp.take_along_axis(x[mask_go_right], right_part, axis=1) - w_next[mask_go_left] = w_left[mask_go_left] - w_next[mask_go_right] = xp.take_along_axis(w[mask_go_right], right_part, axis=1) - - x = x_next - w = w_next - n = i - - return x[:, 0] - - -def _old_weighted_percentile( - array, sample_weight, percentile_rank=50, average=False, xp=None -): - """Compute the weighted percentile. - - Implement an array API compatible (weighted version) of NumPy's 'inverted_cdf' - method when `average=False` (default) and 'averaged_inverted_cdf' when - `average=True`. - - For an array ordered by increasing values, when the percentile lies exactly on a - data point: - - * 'inverted_cdf' takes the exact data point. - * 'averaged_inverted_cdf' takes the average of the exact data point and the one - above it (this means it gives the same result as `median` for unit weights). - - E.g., for the array [1, 2, 3, 4] the percentile rank at each data point would - be [25, 50, 75, 100]. Percentile rank 50 lies on '2'. 'average_inverted_cdf' - computes the average of '2' and '3', making it 'symmetrical' because if you - reverse the array, rank 50 would fall on '3'. It also matches 'median'. - On the other hand, 'inverted_cdf', which does not satisfy the symmetry property, - would give '2'. - - When the requested percentile lies between two data points, both methods return - the higher data point. - E.g., for the array [1, 2, 3, 4, 5] the percentile rank at each data point would - be [20, 40, 60, 80, 100]. Percentile rank 50, lies between '2' and '3'. Taking the - higher data point is symmetrical because if you reverse the array, 50 would lie - between '4' and '3'. Both methods match median in this case. - - If `array` is a 2D array, the `values` are selected along axis 0. - - `NaN` values are ignored by setting their weights to 0. If `array` is 2D, this - is done in a column-isolated manner: a `NaN` in the second column, does not impact - the percentile computed for the first column even if `sample_weight` is 1D. - - .. versionchanged:: 0.24 - Accepts 2D `array`. - - .. versionchanged:: 1.7 - Supports handling of `NaN` values. - - .. versionchanged:: 1.8 - Supports `average`, which calculates percentile using the - "averaged_inverted_cdf" method. 
- - Parameters - ---------- - array : 1D or 2D array - Values to take the weighted percentile of. - - sample_weight: 1D or 2D array - Weights for each value in `array`. Must be same shape as `array` or of shape - `(array.shape[0],)`. - - percentile_rank: int or float, default=50 - The probability level of the percentile to compute, in percent. Must be between - 0 and 100. - - average : bool, default=False - If `True`, uses the "averaged_inverted_cdf" quantile method, otherwise - defaults to "inverted_cdf". "averaged_inverted_cdf" is symmetrical with - unit `sample_weight`, such that the total of `sample_weight` below or equal to - `_weighted_percentile(percentile_rank)` is the same as the total of - `sample_weight` above or equal to `_weighted_percentile(100-percentile_rank)`. - This symmetry is not guaranteed with non-unit weights. - xp : array_namespace, default=None - The standard-compatible namespace for `array`. Default: infer. - - Returns - ------- - percentile : scalar or 0D array if `array` 1D (or 0D), array if `array` 2D - Weighted percentile at the requested probability level. - """ - xp, _, device = get_namespace_and_device(array) - # `sample_weight` should follow `array` for dtypes - floating_dtype = _find_matching_floating_dtype(array, xp=xp) - array = xp.asarray(array, dtype=floating_dtype, device=device) - sample_weight = xp.asarray(sample_weight, dtype=floating_dtype, device=device) - - n_dim = array.ndim - if n_dim == 0: - return array - if array.ndim == 1: - array = xp.reshape(array, (-1, 1)) - # When sample_weight 1D, repeat for each array.shape[1] - if array.shape != sample_weight.shape and array.shape[0] == sample_weight.shape[0]: - sample_weight = xp.tile(sample_weight, (array.shape[1], 1)).T # Sort `array` and `sample_weight` along axis=0: sorted_idx = xp.argsort(array, axis=0, stable=False) sorted_weights = xp.take_along_axis(sample_weight, sorted_idx, axis=0) @@ -323,72 +132,81 @@ def _old_weighted_percentile( # `xp.searchsorted` calls take contiguous inputs as a result (for # performance reasons). weight_cdf = xp.cumulative_sum(sorted_weights.T, axis=1) - adjusted_percentile_rank = percentile_rank / 100 * weight_cdf[..., -1] - - # Ignore leading `sample_weight=0` observations when `percentile_rank=0` (#20528) - mask = adjusted_percentile_rank == 0 - adjusted_percentile_rank[mask] = xp.nextafter( - adjusted_percentile_rank[mask], adjusted_percentile_rank[mask] + 1 - ) - # For each feature with index j, find sample index i of the scalar value - # `adjusted_percentile_rank[j]` in 1D array `weight_cdf[j]`, such that: - # weight_cdf[j, i-1] < adjusted_percentile_rank[j] <= weight_cdf[j, i]. - # Note `searchsorted` defaults to equality on the right, whereas Hyndman and Fan - # reference equation has equality on the left. 
- percentile_indices = xp.stack( - [ - xp.searchsorted( - weight_cdf[feature_idx, ...], adjusted_percentile_rank[feature_idx] - ) - for feature_idx in range(weight_cdf.shape[0]) - ], - ) - # `percentile_indices` may be equal to `sorted_idx.shape[0]` due to floating - # point error (see #11813) - max_idx = sorted_idx.shape[0] - 1 - percentile_indices = xp.clip(percentile_indices, 0, max_idx) - - col_indices = xp.arange(array.shape[1], device=device) - percentile_in_sorted = sorted_idx[percentile_indices, col_indices] - - if average: - # From Hyndman and Fan (1996), `fraction_above` is `g` - fraction_above = ( - weight_cdf[col_indices, percentile_indices] - adjusted_percentile_rank + + n_percentiles = percentile_rank.shape[0] + result = xp.empty((n_features, n_percentiles), dtype=floating_dtype, device=device) + + for p_idx, p_rank in enumerate(percentile_rank): + adjusted_percentile_rank = p_rank / 100 * weight_cdf[..., -1] + + # Ignore leading `sample_weight=0` observations + # when `percentile_rank=0` (#20528) + mask = adjusted_percentile_rank == 0 + adjusted_percentile_rank[mask] = xp.nextafter( + adjusted_percentile_rank[mask], adjusted_percentile_rank[mask] + 1 ) - is_fraction_above = fraction_above > xp.finfo(floating_dtype).eps - percentile_plus_one_indices = xp.clip(percentile_indices + 1, 0, max_idx) - percentile_plus_one_in_sorted = sorted_idx[ - percentile_plus_one_indices, col_indices - ] - # Handle case when next index ('plus one') has sample weight of 0 - zero_weight_cols = col_indices[ - sample_weight[percentile_plus_one_in_sorted, col_indices] == 0 - ] - for col_idx in zero_weight_cols: - cdf_val = weight_cdf[col_idx, percentile_indices[col_idx]] - # Search for next index where `weighted_cdf` is greater - next_index = xp.searchsorted( - weight_cdf[col_idx, ...], cdf_val, side="right" + # For each feature with index j, find sample index i of the scalar value + # `adjusted_percentile_rank[j]` in 1D array `weight_cdf[j]`, such that: + # weight_cdf[j, i-1] < adjusted_percentile_rank[j] <= weight_cdf[j, i]. + # Note `searchsorted` defaults to equality on the right, whereas Hyndman and Fan + # reference equation has equality on the left. 
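+        # A worked illustration with hypothetical values: for one feature with
+        # sorted values [1, 2, 3] and weights [1, 4, 1], weight_cdf is [1, 5, 6].
+        # For p_rank = 50 the adjusted rank is 0.5 * 6 = 3, and
+        # searchsorted([1, 5, 6], 3) returns index 1, so the value 2 is
+        # selected, as expected for "inverted_cdf".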
+ percentile_indices = xp.stack( + [ + xp.searchsorted( + weight_cdf[feature_idx, ...], adjusted_percentile_rank[feature_idx] + ) + for feature_idx in range(weight_cdf.shape[0]) + ], + ) + # `percentile_indices` may be equal to `sorted_idx.shape[0]` due to floating + # point error (see #11813) + max_idx = sorted_idx.shape[0] - 1 + percentile_indices = xp.clip(percentile_indices, 0, max_idx) + + col_indices = xp.arange(array.shape[1], device=device) + percentile_in_sorted = sorted_idx[percentile_indices, col_indices] + + if average: + # From Hyndman and Fan (1996), `fraction_above` is `g` + fraction_above = ( + weight_cdf[col_indices, percentile_indices] - adjusted_percentile_rank ) - # Handle case where there are trailing 0 sample weight samples - # and `percentile_indices` is already max index - if next_index >= max_idx: - # use original `percentile_indices` again - next_index = percentile_indices[col_idx] - - percentile_plus_one_in_sorted[col_idx] = sorted_idx[next_index, col_idx] - - result = xp.where( - is_fraction_above, - array[percentile_in_sorted, col_indices], - ( - array[percentile_in_sorted, col_indices] - + array[percentile_plus_one_in_sorted, col_indices] + is_fraction_above = fraction_above > xp.finfo(floating_dtype).eps + percentile_plus_one_indices = xp.clip(percentile_indices + 1, 0, max_idx) + percentile_plus_one_in_sorted = sorted_idx[ + percentile_plus_one_indices, col_indices + ] + # Handle case when next index ('plus one') has sample weight of 0 + zero_weight_cols = col_indices[ + sample_weight[percentile_plus_one_in_sorted, col_indices] == 0 + ] + for col_idx in zero_weight_cols: + cdf_val = weight_cdf[col_idx, percentile_indices[col_idx]] + # Search for next index where `weighted_cdf` is greater + next_index = xp.searchsorted( + weight_cdf[col_idx, ...], cdf_val, side="right" + ) + # Handle case where there are trailing 0 sample weight samples + # and `percentile_indices` is already max index + if next_index >= max_idx: + # use original `percentile_indices` again + next_index = percentile_indices[col_idx] + + percentile_plus_one_in_sorted[col_idx] = sorted_idx[next_index, col_idx] + + result[..., p_idx] = xp.where( + is_fraction_above, + array[percentile_in_sorted, col_indices], + ( + array[percentile_in_sorted, col_indices] + + array[percentile_plus_one_in_sorted, col_indices] + ) + / 2, ) - / 2, - ) - else: - result = array[percentile_in_sorted, col_indices] + else: + result[..., p_idx] = array[percentile_in_sorted, col_indices] + + if n_dim_percentile == 0: + result = result[..., 0] return result[0] if n_dim == 1 else result From e327b37bf658e2c4e90b4864fe1f21796f9ea567 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 26 Oct 2025 14:19:15 +0100 Subject: [PATCH 19/22] using xpx.quantile everywhere --- sklearn/_loss/loss.py | 26 ++-- sklearn/calibration.py | 2 +- sklearn/dummy.py | 21 +-- sklearn/ensemble/_gb.py | 6 +- .../array_api_extra/_lib/_quantile.py | 2 +- sklearn/metrics/_regression.py | 26 ++-- sklearn/preprocessing/_discretization.py | 30 +--- sklearn/preprocessing/_polynomial.py | 16 +- sklearn/tests/test_dummy.py | 7 +- sklearn/utils/stats.py | 138 ++---------------- 10 files changed, 77 insertions(+), 197 deletions(-) diff --git a/sklearn/_loss/loss.py b/sklearn/_loss/loss.py index 9cbaa5284d3a2..d0aa5a7bc12c1 100644 --- a/sklearn/_loss/loss.py +++ b/sklearn/_loss/loss.py @@ -45,8 +45,8 @@ LogLink, MultinomialLogit, ) +from sklearn.externals import array_api_extra as xpx from sklearn.utils import check_scalar -from sklearn.utils.stats import 
_weighted_percentile # Note: The shape of raw_prediction for multiclass classifications are @@ -588,7 +588,13 @@ def fit_intercept_only(self, y_true, sample_weight=None): if sample_weight is None: return np.median(y_true, axis=0) else: - return _weighted_percentile(y_true, sample_weight, 50) + return xpx.quantile( + y_true, + 0.5, + axis=0, + weights=sample_weight, + method="averaged_inverted_cdf", + ) class PinballLoss(BaseLoss): @@ -646,12 +652,10 @@ def fit_intercept_only(self, y_true, sample_weight=None): This is the weighted median of the target, i.e. over the samples axis=0. """ - if sample_weight is None: - return np.percentile(y_true, 100 * self.closs.quantile, axis=0) - else: - return _weighted_percentile( - y_true, sample_weight, 100 * self.closs.quantile - ) + method = "linear" if sample_weight is None else "averaged_inverted_cdf" + return xpx.quantile( + y_true, self.closs.quantile, axis=0, method=method, weights=sample_weight + ) class HuberLoss(BaseLoss): @@ -718,10 +722,8 @@ def fit_intercept_only(self, y_true, sample_weight=None): # not to the residual y_true - raw_prediction. An estimator like # HistGradientBoostingRegressor might then call it on the residual, e.g. # fit_intercept_only(y_true - raw_prediction). - if sample_weight is None: - median = np.percentile(y_true, 50, axis=0) - else: - median = _weighted_percentile(y_true, sample_weight, 50) + method = "linear" if sample_weight is None else "averaged_inverted_cdf" + median = xpx.quantile(y_true, 0.5, axis=0, method=method, weights=sample_weight) diff = y_true - median term = np.sign(diff) * np.minimum(self.closs.delta, np.abs(diff)) return median + np.average(term, weights=sample_weight) diff --git a/sklearn/calibration.py b/sklearn/calibration.py index eaadc80cd503a..4ac7913deb958 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -1294,7 +1294,7 @@ def calibration_curve( if strategy == "quantile": # Determine bin edges by distribution of data quantiles = np.linspace(0, 1, n_bins + 1) - bins = np.percentile(y_prob, quantiles * 100) + bins = xpx.quantile(y_prob, quantiles) elif strategy == "uniform": bins = np.linspace(0.0, 1.0, n_bins + 1) else: diff --git a/sklearn/dummy.py b/sklearn/dummy.py index f0823567abd9e..6eee159c8257c 100644 --- a/sklearn/dummy.py +++ b/sklearn/dummy.py @@ -16,11 +16,11 @@ RegressorMixin, _fit_context, ) +from sklearn.externals import array_api_extra as xpx from sklearn.utils import check_random_state from sklearn.utils._param_validation import Interval, StrOptions from sklearn.utils.multiclass import class_distribution from sklearn.utils.random import _random_choice_csc -from sklearn.utils.stats import _weighted_percentile from sklearn.utils.validation import ( _check_sample_weight, _num_samples, @@ -581,8 +581,12 @@ def fit(self, X, y, sample_weight=None): if sample_weight is None: self.constant_ = np.median(y, axis=0) else: - self.constant_ = _weighted_percentile( - y, sample_weight, percentile_rank=50.0 + self.constant_ = xpx.quantile( + y, + 0.5, + axis=0, + weights=sample_weight, + method="averaged_inverted_cdf", ) elif self.strategy == "quantile": @@ -591,13 +595,10 @@ def fit(self, X, y, sample_weight=None): "When using `strategy='quantile', you have to specify the desired " "quantile in the range [0, 1]." 
) - percentile_rank = self.quantile * 100.0 - if sample_weight is None: - self.constant_ = np.percentile(y, axis=0, q=percentile_rank) - else: - self.constant_ = _weighted_percentile( - y, sample_weight, percentile_rank=percentile_rank - ) + method = "linear" if sample_weight is None else "averaged_inverted_cdf" + self.constant_ = xpx.quantile( + y, float(self.quantile), axis=0, weights=sample_weight, method=method + ) elif self.strategy == "constant": if self.constant is None: diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index e64763123f270..771ade8a79d59 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -54,7 +54,6 @@ from sklearn.utils import check_array, check_random_state, column_or_1d from sklearn.utils._param_validation import HasMethods, Interval, StrOptions from sklearn.utils.multiclass import check_classification_targets -from sklearn.utils.stats import _weighted_percentile from sklearn.utils.validation import ( _check_sample_weight, check_is_fitted, @@ -275,7 +274,10 @@ def set_huber_delta(loss, y_true, raw_prediction, sample_weight=None): """Calculate and set self.closs.delta based on self.quantile.""" abserr = np.abs(y_true - raw_prediction.squeeze()) # sample_weight is always a ndarray, never None. - delta = _weighted_percentile(abserr, sample_weight, 100 * loss.quantile) + print(abserr, sample_weight) + delta = np.quantile( + abserr, loss.quantile, axis=0, weights=sample_weight, method="inverted_cdf" + ) loss.closs.delta = float(delta) diff --git a/sklearn/externals/array_api_extra/_lib/_quantile.py b/sklearn/externals/array_api_extra/_lib/_quantile.py index 4d50dfd4445d4..df23e94f3f059 100644 --- a/sklearn/externals/array_api_extra/_lib/_quantile.py +++ b/sklearn/externals/array_api_extra/_lib/_quantile.py @@ -170,7 +170,7 @@ def _weighted_quantile_sorted_1d( # numpydoc ignore=GL08 i = xp.clip(i, 0, n - 1) i = xp.take(sorter, i) - q0 = q == 0.0 + q0 = t == 0.0 if average or xp.any(q0): j = xp.searchsorted(cdf, t, side="right") j = xp.clip(j, 0, n - 1) diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index 955014484fc5d..32ae59bc408e5 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -23,10 +23,10 @@ get_namespace, get_namespace_and_device, size, + xpx, ) from sklearn.utils._array_api import _xlogy as xlogy from sklearn.utils._param_validation import Interval, StrOptions, validate_params -from sklearn.utils.stats import _weighted_percentile from sklearn.utils.validation import ( _check_sample_weight, _num_samples, @@ -923,8 +923,12 @@ def median_absolute_error( if sample_weight is None: output_errors = _median(xp.abs(y_pred - y_true), axis=0) else: - output_errors = _weighted_percentile( - xp.abs(y_pred - y_true), sample_weight=sample_weight, average=True + output_errors = xpx.quantile( + xp.abs(y_pred - y_true), + 0.5, + axis=0, + weights=sample_weight, + method="averaged_inverted_cdf", ) if isinstance(multioutput, str): if multioutput == "raw_values": @@ -1820,17 +1824,11 @@ def d2_pinball_score( multioutput="raw_values", ) - if sample_weight is None: - y_quantile = np.tile( - np.percentile(y_true, q=alpha * 100, axis=0), (len(y_true), 1) - ) - else: - y_quantile = np.tile( - _weighted_percentile( - y_true, sample_weight=sample_weight, percentile_rank=alpha * 100 - ), - (len(y_true), 1), - ) + method = "linear" if sample_weight is None else "averaged_inverted_cdf" + y_quantile = np.tile( + xpx.quantile(y_true, alpha, axis=0, weights=sample_weight, method=method), + 
(len(y_true), 1), + ) denominator = mean_pinball_loss( y_true, diff --git a/sklearn/preprocessing/_discretization.py b/sklearn/preprocessing/_discretization.py index 847c388599821..a3cd99509cfe3 100644 --- a/sklearn/preprocessing/_discretization.py +++ b/sklearn/preprocessing/_discretization.py @@ -8,10 +8,10 @@ import numpy as np from sklearn.base import BaseEstimator, TransformerMixin, _fit_context +from sklearn.externals import array_api_extra as xpx from sklearn.preprocessing._encoders import OneHotEncoder from sklearn.utils import resample from sklearn.utils._param_validation import Interval, Options, StrOptions -from sklearn.utils.stats import _weighted_percentile from sklearn.utils.validation import ( _check_feature_names_in, _check_sample_weight, @@ -350,27 +350,13 @@ def fit(self, X, y=None, sample_weight=None): bin_edges[jj] = np.linspace(col_min, col_max, n_bins[jj] + 1) elif self.strategy == "quantile": - percentile_levels = np.linspace(0, 100, n_bins[jj] + 1) - - # method="linear" is the implicit default for any numpy - # version. So we keep it version independent in that case by - # using an empty param dict. - percentile_kwargs = {} - if quantile_method != "linear" and sample_weight is None: - percentile_kwargs["method"] = quantile_method - - if sample_weight is None: - bin_edges[jj] = np.asarray( - np.percentile(column, percentile_levels, **percentile_kwargs), - dtype=np.float64, - ) - else: - average = ( - True if quantile_method == "averaged_inverted_cdf" else False - ) - bin_edges[jj] = _weighted_percentile( - column, sample_weight, percentile_levels, average=average - ) + quantile_levels = np.linspace(0, 1, n_bins[jj] + 1) + bin_edges[jj] = xpx.quantile( + column, + quantile_levels, + weights=sample_weight, + method=quantile_method, + ) elif self.strategy == "kmeans": from sklearn.cluster import KMeans # fixes import loops diff --git a/sklearn/preprocessing/_polynomial.py b/sklearn/preprocessing/_polynomial.py index e34b25fbbdd88..6eee6e5d00bb9 100644 --- a/sklearn/preprocessing/_polynomial.py +++ b/sklearn/preprocessing/_polynomial.py @@ -16,6 +16,7 @@ from scipy.special import comb from sklearn.base import BaseEstimator, TransformerMixin, _fit_context +from sklearn.externals import array_api_extra as xpx from sklearn.preprocessing._csr_polynomial_expansion import ( _calc_expanded_nnz, _calc_total_nnz, @@ -30,7 +31,6 @@ from sklearn.utils._mask import _get_mask from sklearn.utils._param_validation import Interval, StrOptions from sklearn.utils.fixes import parse_version, sp_version -from sklearn.utils.stats import _weighted_percentile from sklearn.utils.validation import ( FLOAT_DTYPES, _check_feature_names_in, @@ -784,14 +784,18 @@ def _get_base_knot_positions(X, n_knots=10, knots="uniform", sample_weight=None) Knot positions (points) of base interval. 
""" if knots == "quantile": - percentile_ranks = 100 * np.linspace( - start=0, stop=1, num=n_knots, dtype=np.float64 - ) + quantile_ranks = np.linspace(start=0, stop=1, num=n_knots, dtype=np.float64) if sample_weight is None: - knots = np.nanpercentile(X, percentile_ranks, axis=0) + knots = np.nanquantile(X, quantile_ranks, axis=0) else: - knots = _weighted_percentile(X, sample_weight, percentile_ranks).T + knots = xpx.quantile( + X, + quantile_ranks, + axis=0, + weights=sample_weight, + method="averaged_inverted_cdf", + ) else: # knots == 'uniform': diff --git a/sklearn/tests/test_dummy.py b/sklearn/tests/test_dummy.py index 61f1803b7a24f..995cf468273ee 100644 --- a/sklearn/tests/test_dummy.py +++ b/sklearn/tests/test_dummy.py @@ -7,13 +7,13 @@ from sklearn.base import clone from sklearn.dummy import DummyClassifier, DummyRegressor from sklearn.exceptions import NotFittedError +from sklearn.externals import array_api_extra as xpx from sklearn.utils._testing import ( assert_almost_equal, assert_array_almost_equal, assert_array_equal, ) from sklearn.utils.fixes import CSC_CONTAINERS -from sklearn.utils.stats import _weighted_percentile def _check_predict_proba(clf, X, y): @@ -631,11 +631,12 @@ def test_dummy_regressor_sample_weight(global_random_seed, n_samples=10): est = DummyRegressor(strategy="mean").fit(X, y, sample_weight) assert est.constant_ == np.average(y, weights=sample_weight) + method = "averaged_inverted_cdf" est = DummyRegressor(strategy="median").fit(X, y, sample_weight) - assert est.constant_ == _weighted_percentile(y, sample_weight, 50.0) + assert est.constant_ == xpx.quantile(y, 0.5, weights=sample_weight, method=method) est = DummyRegressor(strategy="quantile", quantile=0.95).fit(X, y, sample_weight) - assert est.constant_ == _weighted_percentile(y, sample_weight, 95.0) + assert est.constant_ == xpx.quantile(y, 0.95, weights=sample_weight, method=method) def test_dummy_regressor_on_3D_array(): diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py index a2827fc503b8d..6c6e62a9e01fb 100644 --- a/sklearn/utils/stats.py +++ b/sklearn/utils/stats.py @@ -1,10 +1,7 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause -from sklearn.utils._array_api import ( - _find_matching_floating_dtype, - get_namespace_and_device, -) +from sklearn.externals import array_api_extra as xpx def _weighted_percentile( @@ -87,126 +84,15 @@ def _weighted_percentile( If `array` is 1D and `percentile_rank` is 1D, returns a 1D array of shape `(percentile_rank.shape[0],)` If `array` is 2D and `percentile_rank` is 1D, returns a 2D array - of shape `(array.shape[1], percentile_rank.shape[0])` + of shape `(percentile_rank.shape[0], array.shape[1])` """ - xp, _, device = get_namespace_and_device(array) - # `sample_weight` should follow `array` for dtypes - floating_dtype = _find_matching_floating_dtype(array, xp=xp) - array = xp.asarray(array, dtype=floating_dtype, device=device) - sample_weight = xp.asarray(sample_weight, dtype=floating_dtype, device=device) - percentile_rank = xp.asarray(percentile_rank, dtype=floating_dtype, device=device) - - n_dim = array.ndim - if n_dim == 0: - return array - if array.ndim == 1: - array = xp.reshape(array, (-1, 1)) - # When sample_weight 1D, repeat for each array.shape[1] - if array.shape != sample_weight.shape and array.shape[0] == sample_weight.shape[0]: - sample_weight = xp.tile(sample_weight, (array.shape[1], 1)).T - - n_dim_percentile = percentile_rank.ndim - if n_dim_percentile == 0: - percentile_rank = 
xp.reshape(percentile_rank, (1,)) - - # Sort `array` and `sample_weight` along axis=0: - sorted_idx = xp.argsort(array, axis=0, stable=False) - sorted_weights = xp.take_along_axis(sample_weight, sorted_idx, axis=0) - - # Set NaN values in `sample_weight` to 0. Only perform this operation if NaN - # values present to avoid temporary allocations of size `(n_samples, n_features)`. - n_features = array.shape[1] - largest_value_per_column = array[ - sorted_idx[-1, ...], xp.arange(n_features, device=device) - ] - # NaN values get sorted to end (largest value) - if xp.any(xp.isnan(largest_value_per_column)): - sorted_nan_mask = xp.take_along_axis(xp.isnan(array), sorted_idx, axis=0) - sorted_weights[sorted_nan_mask] = 0 - - # Compute the weighted cumulative distribution function (CDF) based on - # `sample_weight` and scale `percentile_rank` along it. - # - # Note: we call `xp.cumulative_sum` on the transposed `sorted_weights` to - # ensure that the result is of shape `(n_features, n_samples)` so - # `xp.searchsorted` calls take contiguous inputs as a result (for - # performance reasons). - weight_cdf = xp.cumulative_sum(sorted_weights.T, axis=1) - - n_percentiles = percentile_rank.shape[0] - result = xp.empty((n_features, n_percentiles), dtype=floating_dtype, device=device) - - for p_idx, p_rank in enumerate(percentile_rank): - adjusted_percentile_rank = p_rank / 100 * weight_cdf[..., -1] - - # Ignore leading `sample_weight=0` observations - # when `percentile_rank=0` (#20528) - mask = adjusted_percentile_rank == 0 - adjusted_percentile_rank[mask] = xp.nextafter( - adjusted_percentile_rank[mask], adjusted_percentile_rank[mask] + 1 - ) - # For each feature with index j, find sample index i of the scalar value - # `adjusted_percentile_rank[j]` in 1D array `weight_cdf[j]`, such that: - # weight_cdf[j, i-1] < adjusted_percentile_rank[j] <= weight_cdf[j, i]. - # Note `searchsorted` defaults to equality on the right, whereas Hyndman and Fan - # reference equation has equality on the left. 
- percentile_indices = xp.stack( - [ - xp.searchsorted( - weight_cdf[feature_idx, ...], adjusted_percentile_rank[feature_idx] - ) - for feature_idx in range(weight_cdf.shape[0]) - ], - ) - # `percentile_indices` may be equal to `sorted_idx.shape[0]` due to floating - # point error (see #11813) - max_idx = sorted_idx.shape[0] - 1 - percentile_indices = xp.clip(percentile_indices, 0, max_idx) - - col_indices = xp.arange(array.shape[1], device=device) - percentile_in_sorted = sorted_idx[percentile_indices, col_indices] - - if average: - # From Hyndman and Fan (1996), `fraction_above` is `g` - fraction_above = ( - weight_cdf[col_indices, percentile_indices] - adjusted_percentile_rank - ) - is_fraction_above = fraction_above > xp.finfo(floating_dtype).eps - percentile_plus_one_indices = xp.clip(percentile_indices + 1, 0, max_idx) - percentile_plus_one_in_sorted = sorted_idx[ - percentile_plus_one_indices, col_indices - ] - # Handle case when next index ('plus one') has sample weight of 0 - zero_weight_cols = col_indices[ - sample_weight[percentile_plus_one_in_sorted, col_indices] == 0 - ] - for col_idx in zero_weight_cols: - cdf_val = weight_cdf[col_idx, percentile_indices[col_idx]] - # Search for next index where `weighted_cdf` is greater - next_index = xp.searchsorted( - weight_cdf[col_idx, ...], cdf_val, side="right" - ) - # Handle case where there are trailing 0 sample weight samples - # and `percentile_indices` is already max index - if next_index >= max_idx: - # use original `percentile_indices` again - next_index = percentile_indices[col_idx] - - percentile_plus_one_in_sorted[col_idx] = sorted_idx[next_index, col_idx] - - result[..., p_idx] = xp.where( - is_fraction_above, - array[percentile_in_sorted, col_indices], - ( - array[percentile_in_sorted, col_indices] - + array[percentile_plus_one_in_sorted, col_indices] - ) - / 2, - ) - else: - result[..., p_idx] = array[percentile_in_sorted, col_indices] - - if n_dim_percentile == 0: - result = result[..., 0] - - return result[0] if n_dim == 1 else result + method = "averaged_inverted_cdf" if average else "inverted_cdf" + return xpx.quantile( + array, + percentile_rank / 100, + axis=0, + method=method, + weights=sample_weight, + xp=xp, + nan_policy="omit", + ) From de9b63cd65b95d9fcc9143b433db8f63a2bfca69 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 26 Oct 2025 14:39:54 +0100 Subject: [PATCH 20/22] wip --- sklearn/ensemble/_gb.py | 3 ++- sklearn/ensemble/tests/test_gradient_boosting.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index 771ade8a79d59..49dbdd619fa3f 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -274,10 +274,11 @@ def set_huber_delta(loss, y_true, raw_prediction, sample_weight=None): """Calculate and set self.closs.delta based on self.quantile.""" abserr = np.abs(y_true - raw_prediction.squeeze()) # sample_weight is always a ndarray, never None. 
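+    # Note: `np.quantile` only accepts `weights` together with
+    # method="inverted_cdf" (added in NumPy 2.0); other methods reject weights.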
- print(abserr, sample_weight) + print((np.sort(abserr) * 10000).round(4)) delta = np.quantile( abserr, loss.quantile, axis=0, weights=sample_weight, method="inverted_cdf" ) + print(delta) loss.closs.delta = float(delta) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 20866348697f6..8dc18dce792c2 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -1565,6 +1565,7 @@ def test_huber_exact_backward_compat(): gbt = GradientBoostingRegressor(loss="huber", n_estimators=100, alpha=0.8).fit(X, y) assert_allclose(gbt._loss.closs.delta, 0.0001655688041282133) + return pred_result = np.array( [ From 89e650367cd5126c4de8740e150ed630ec498824 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 26 Oct 2025 14:45:02 +0100 Subject: [PATCH 21/22] wip --- sklearn/ensemble/_gb.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index 49dbdd619fa3f..52c2588b1cf7b 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -274,6 +274,7 @@ def set_huber_delta(loss, y_true, raw_prediction, sample_weight=None): """Calculate and set self.closs.delta based on self.quantile.""" abserr = np.abs(y_true - raw_prediction.squeeze()) # sample_weight is always a ndarray, never None. + assert (sample_weight == 1.).all() print((np.sort(abserr) * 10000).round(4)) delta = np.quantile( abserr, loss.quantile, axis=0, weights=sample_weight, method="inverted_cdf" From 202430caacb8057cc0c75c8c1831767fb6953ac4 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 26 Oct 2025 14:56:19 +0100 Subject: [PATCH 22/22] fixed backward compat test --- sklearn/_loss/loss.py | 9 ++++++++- sklearn/ensemble/_gb.py | 8 ++++---- sklearn/ensemble/tests/test_gradient_boosting.py | 1 - 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/sklearn/_loss/loss.py b/sklearn/_loss/loss.py index d0aa5a7bc12c1..c8269fc46d13a 100644 --- a/sklearn/_loss/loss.py +++ b/sklearn/_loss/loss.py @@ -722,7 +722,14 @@ def fit_intercept_only(self, y_true, sample_weight=None): # not to the residual y_true - raw_prediction. An estimator like # HistGradientBoostingRegressor might then call it on the residual, e.g. # fit_intercept_only(y_true - raw_prediction). 
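+        # Illustration of why the method chosen below matters (hypothetical
+        # values): for y_true = [1, 2, 3, 4] with unit weights, "inverted_cdf"
+        # at q=0.5 returns 2, while the unweighted "linear" path (the median)
+        # and "averaged_inverted_cdf" both return 2.5.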
-        method = "linear" if sample_weight is None else "averaged_inverted_cdf"
+
+        method = "linear" if sample_weight is None else "inverted_cdf"
+        # XXX: it would be better to use method "averaged_inverted_cdf" for
+        # the weighted case, because otherwise passing unit weights is not
+        # equivalent to passing no weights at all. However, switching would
+        # break this test:
+        # ensemble/tests/test_gradient_boosting.py::test_huber_exact_backward_compat
+
         median = xpx.quantile(y_true, 0.5, axis=0, method=method, weights=sample_weight)
         diff = y_true - median
         term = np.sign(diff) * np.minimum(self.closs.delta, np.abs(diff))
         return median + np.average(term, weights=sample_weight)
diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py
index 52c2588b1cf7b..d48a51d1cdf3a 100644
--- a/sklearn/ensemble/_gb.py
+++ b/sklearn/ensemble/_gb.py
@@ -47,6 +47,7 @@
     predict_stages,
 )
 from sklearn.exceptions import NotFittedError
+from sklearn.externals import array_api_extra as xpx
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import LabelEncoder
 from sklearn.tree import DecisionTreeRegressor
@@ -274,12 +275,11 @@ def set_huber_delta(loss, y_true, raw_prediction, sample_weight=None):
     """Calculate and set self.closs.delta based on self.quantile."""
     abserr = np.abs(y_true - raw_prediction.squeeze())
     # sample_weight is always a ndarray, never None.
-    assert (sample_weight == 1.).all()
-    print((np.sort(abserr) * 10000).round(4))
-    delta = np.quantile(
+    delta = xpx.quantile(
         abserr, loss.quantile, axis=0, weights=sample_weight, method="inverted_cdf"
     )
-    print(delta)
+    # XXX: it would probably be better to use method "averaged_inverted_cdf";
+    # see HuberLoss.fit_intercept_only for an explanation of why we cannot.
     loss.closs.delta = float(delta)


diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py
index 8dc18dce792c2..20866348697f6 100644
--- a/sklearn/ensemble/tests/test_gradient_boosting.py
+++ b/sklearn/ensemble/tests/test_gradient_boosting.py
@@ -1565,7 +1565,6 @@ def test_huber_exact_backward_compat():
     gbt = GradientBoostingRegressor(loss="huber", n_estimators=100, alpha=0.8).fit(X, y)

     assert_allclose(gbt._loss.closs.delta, 0.0001655688041282133)
-    return

     pred_result = np.array(
         [