Skip to content

Commit 5dab92a

Browse files
committed
Extend ndarray-rand to be able to randomly sample from ArrayRef.
Prior to `ndarray` 0.17, the `RandomExt` trait exposed by `ndarray-rand` contained methods for both creating new arrays randomly whole-cloth (`random_using`) and sampling from existing arrays (`sample_axis_using`). With the introduction of reference types in `ndarray` 0.17, users should be able to sample from `ArrayRef` instances as well. We choose to expose an additional extension trait, `RandomRefExt`, that provides this functionality. We keep the methods on the old trait for backwards compatibility, but collapse the implementation and documentation to the new trait to maintain a single source of truth.
1 parent 66dc0e1 commit 5dab92a

File tree

1 file changed

+65
-6
lines changed

1 file changed

+65
-6
lines changed

ndarray-rand/src/lib.rs

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,14 @@
2929
//! that the items are not compatible (e.g. that a type doesn't implement a
3030
//! necessary trait).
3131
32+
#![warn(missing_docs)]
33+
3234
use crate::rand::distr::{Distribution, Uniform};
3335
use crate::rand::rngs::SmallRng;
3436
use crate::rand::seq::index;
3537
use crate::rand::{rng, Rng, SeedableRng};
3638

37-
use ndarray::{Array, Axis, RemoveAxis, ShapeBuilder};
39+
use ndarray::{Array, ArrayRef, Axis, RemoveAxis, ShapeBuilder};
3840
use ndarray::{ArrayBase, Data, DataOwned, Dimension, RawData};
3941
#[cfg(feature = "quickcheck")]
4042
use quickcheck::{Arbitrary, Gen};
@@ -124,6 +126,43 @@ where
124126
S: DataOwned<Elem = A>,
125127
Sh: ShapeBuilder<Dim = D>;
126128

129+
/// Sample `n_samples` lanes slicing along `axis` using the default RNG.
130+
///
131+
/// See [`RandomRefExt::sample_axis`] for additional information.
132+
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
133+
where
134+
A: Copy,
135+
S: Data<Elem = A>,
136+
D: RemoveAxis;
137+
138+
/// Sample `n_samples` lanes slicing along `axis` using the specified RNG `rng`.
139+
///
140+
/// See [`RandomRefExt::sample_axis_using`] for additional information.
141+
fn sample_axis_using<R>(
142+
&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R,
143+
) -> Array<A, D>
144+
where
145+
R: Rng + ?Sized,
146+
A: Copy,
147+
S: Data<Elem = A>,
148+
D: RemoveAxis;
149+
}
150+
151+
/// Constructors for sampling from [`ArrayRef`] with random elements.
152+
///
153+
/// This trait extends ndarray’s `ArrayRef` and can not be implemented
154+
/// for other types.
155+
///
156+
/// The default RNG is a fast automatically seeded rng (currently
157+
/// [`rand::rngs::SmallRng`], seeded from [`rand::thread_rng`]).
158+
///
159+
/// Note that `SmallRng` is cheap to initialize and fast, but it may generate
160+
/// low-quality random numbers, and reproducibility is not guaranteed. See its
161+
/// documentation for information. You can select a different RNG with
162+
/// [`.random_using()`](Self::random_using).
163+
pub trait RandomRefExt<A, D>
164+
where D: Dimension
165+
{
127166
/// Sample `n_samples` lanes slicing along `axis` using the default RNG.
128167
///
129168
/// If `strategy==SamplingStrategy::WithoutReplacement`, each lane can only be sampled once.
@@ -168,7 +207,6 @@ where
168207
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
169208
where
170209
A: Copy,
171-
S: Data<Elem = A>,
172210
D: RemoveAxis;
173211

174212
/// Sample `n_samples` lanes slicing along `axis` using the specified RNG `rng`.
@@ -225,7 +263,6 @@ where
225263
where
226264
R: Rng + ?Sized,
227265
A: Copy,
228-
S: Data<Elem = A>,
229266
D: RemoveAxis;
230267
}
231268

@@ -259,7 +296,7 @@ where
259296
S: Data<Elem = A>,
260297
D: RemoveAxis,
261298
{
262-
self.sample_axis_using(axis, n_samples, strategy, &mut get_rng())
299+
(**self).sample_axis(axis, n_samples, strategy)
263300
}
264301

265302
fn sample_axis_using<R>(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R) -> Array<A, D>
@@ -268,6 +305,27 @@ where
268305
A: Copy,
269306
S: Data<Elem = A>,
270307
D: RemoveAxis,
308+
{
309+
(&**self).sample_axis_using(axis, n_samples, strategy, rng)
310+
}
311+
}
312+
313+
impl<A, D> RandomRefExt<A, D> for ArrayRef<A, D>
314+
where D: Dimension
315+
{
316+
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
317+
where
318+
A: Copy,
319+
D: RemoveAxis,
320+
{
321+
self.sample_axis_using(axis, n_samples, strategy, &mut get_rng())
322+
}
323+
324+
fn sample_axis_using<R>(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R) -> Array<A, D>
325+
where
326+
R: Rng + ?Sized,
327+
A: Copy,
328+
D: RemoveAxis,
271329
{
272330
let indices: Vec<_> = match strategy {
273331
SamplingStrategy::WithReplacement => {
@@ -284,9 +342,10 @@ where
284342
/// if lanes from the original array should only be sampled once (*without replacement*) or
285343
/// multiple times (*with replacement*).
286344
///
287-
/// [`sample_axis`]: RandomExt::sample_axis
288-
/// [`sample_axis_using`]: RandomExt::sample_axis_using
345+
/// [`sample_axis`]: RandomRefExt::sample_axis
346+
/// [`sample_axis_using`]: RandomRefExt::sample_axis_using
289347
#[derive(Debug, Clone)]
348+
#[allow(missing_docs)]
290349
pub enum SamplingStrategy
291350
{
292351
WithReplacement,

0 commit comments

Comments
 (0)