Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ categories = ["science", "algorithms", "mathematics"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
ndarray = "0.15"
ndarray = "0.16"
num-traits = "0.2"
thiserror = "1.0"

[dev-dependencies]
cargo-tarpaulin = "0.27"
ndarray = {version = "0.15", features = ["approx-0_5", "rayon"] }
ndarray = {version = "0.16", features = ["approx", "rayon"] }
approx = "0.5"
criterion = "0.5"
rand = "0.8"
Expand Down
6 changes: 3 additions & 3 deletions benches/bench_interp1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ fn bench_interp1d_scalar_multithread(c: &mut Criterion) {
})
});

let query = query.into_shape((2500, 4)).unwrap();
let query = query.into_shape_with_order((2500, 4)).unwrap();
let query_arr: Vec<_> = query.axis_iter(Axis(0)).collect();
c.bench_function("1D scalar MT `interp_array`", |b| {
b.iter(|| {
Expand All @@ -80,7 +80,7 @@ fn bench_interp1d_scalar_multithread(c: &mut Criterion) {

fn bench_interp1d_array(c: &mut Criterion) {
let data = Array::from_rand(500, (0.0, 1.0), 69)
.into_shape((100, 5))
.into_shape_with_order((100, 5))
.unwrap();
let interp = Interp1D::builder(data).build().unwrap();
let query = Array::from_rand(10_000, (0.0, 99.0), 123);
Expand All @@ -102,7 +102,7 @@ fn bench_interp1d_array(c: &mut Criterion) {
})
});

let query = query.into_shape((2500, 4)).unwrap();
let query = query.into_shape_with_order((2500, 4)).unwrap();
let query_arr: Vec<_> = query.axis_iter(Axis(0)).collect();
c.bench_function("1D array `interp_array`", |b| {
b.iter(|| {
Expand Down
6 changes: 3 additions & 3 deletions benches/bench_interp1d_query_dim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn bench_scalar_data_1d_query(c: &mut Criterion) {
let interp = Interp1D::builder(data).build().unwrap();
let query = Array::from_rand(10_000, (0.0, 99.0), 123);

let query = query.into_shape((2500, 4)).unwrap();
let query = query.into_shape_with_order((2500, 4)).unwrap();
let query_arr: Vec<_> = query.axis_iter(Axis(0)).collect();
c.bench_function("1D scalar `interp_array`", |b| {
b.iter(|| {
Expand All @@ -38,7 +38,7 @@ fn bench_scalar_data_2d_query(c: &mut Criterion) {
let interp = Interp1D::builder(data).build().unwrap();
let query = Array::from_rand(10_000, (0.0, 99.0), 123);

let query = query.into_shape((625, 4, 4)).unwrap();
let query = query.into_shape_with_order((625, 4, 4)).unwrap();
let query_arr: Vec<_> = query.axis_iter(Axis(0)).collect();
c.bench_function("1D scalar `interp_array` 2D-query", |b| {
b.iter(|| {
Expand All @@ -63,7 +63,7 @@ fn bench_scalar_data_3d_query(c: &mut Criterion) {
let interp = Interp1D::builder(data).build().unwrap();
let query = Array::from_rand(10_000, (0.0, 99.0), 123);

let query = query.into_shape((125, 5, 4, 4)).unwrap();
let query = query.into_shape_with_order((125, 5, 4, 4)).unwrap();
let query_arr: Vec<_> = query.axis_iter(Axis(0)).collect();
c.bench_function("1D scalar `interp_array` 3D-query", |b| {
b.iter(|| {
Expand Down
14 changes: 7 additions & 7 deletions benches/bench_interp2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ mod rand_extensions;

fn bench_interp2d_scalar(c: &mut Criterion) {
let data = Array::from_rand(10_000, (0.0, 1.0), 42)
.into_shape((100, 100))
.into_shape_with_order((100, 100))
.unwrap();
let interp = Interp2D::builder(data).build().unwrap();
let query_x = Array::from_rand(10_000, (0.0, 99.0), 123);
Expand Down Expand Up @@ -45,7 +45,7 @@ fn bench_interp2d_scalar(c: &mut Criterion) {

fn bench_interp2d_scalar_multithread(c: &mut Criterion) {
let data = Array::from_rand(10_000, (0.0, 1.0), 42)
.into_shape((100, 100))
.into_shape_with_order((100, 100))
.unwrap();
let interp = Interp2D::builder(data).build().unwrap();
let query_x = Array::from_rand(10_000, (0.0, 99.0), 123);
Expand All @@ -67,8 +67,8 @@ fn bench_interp2d_scalar_multithread(c: &mut Criterion) {
})
});

let query_x = query_x.into_shape((2500, 4)).unwrap();
let query_y = query_y.into_shape((2500, 4)).unwrap();
let query_x = query_x.into_shape_with_order((2500, 4)).unwrap();
let query_y = query_y.into_shape_with_order((2500, 4)).unwrap();
let query_arrx: Vec<_> = query_x.axis_iter(Axis(0)).collect();
let query_arry: Vec<_> = query_y.axis_iter(Axis(0)).collect();
c.bench_function("2D scalar MT `interp_array`", |b| {
Expand All @@ -85,7 +85,7 @@ fn bench_interp2d_scalar_multithread(c: &mut Criterion) {

fn bench_interp2d_array(c: &mut Criterion) {
let data = Array::from_rand(50_000, (0.0, 1.0), 69)
.into_shape((100, 100, 5))
.into_shape_with_order((100, 100, 5))
.unwrap();
let interp = Interp2D::builder(data).build().unwrap();
let query_x = Array::from_rand(10_000, (0.0, 99.0), 123);
Expand All @@ -108,8 +108,8 @@ fn bench_interp2d_array(c: &mut Criterion) {
})
});

let query_x = query_x.into_shape((2500, 4)).unwrap();
let query_y = query_y.into_shape((2500, 4)).unwrap();
let query_x = query_x.into_shape_with_order((2500, 4)).unwrap();
let query_y = query_y.into_shape_with_order((2500, 4)).unwrap();
let query_arrx: Vec<_> = query_x.axis_iter(Axis(0)).collect();
let query_arry: Vec<_> = query_y.axis_iter(Axis(0)).collect();
c.bench_function("2D array `interp_array`", |b| {
Expand Down
14 changes: 7 additions & 7 deletions benches/bench_interp2d_query_dim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod rand_extensions;

fn setup() -> (Interp2DScalar<f64, Bilinear>, Array1<f64>, Array1<f64>) {
let data = Array::from_rand(10_000, (0.0, 1.0), 42)
.into_shape((100, 100))
.into_shape_with_order((100, 100))
.unwrap();
let interp = Interp2D::builder(data).build().unwrap();
let query_x = Array::from_rand(10_000, (0.0, 99.0), 123);
Expand All @@ -18,8 +18,8 @@ fn setup() -> (Interp2DScalar<f64, Bilinear>, Array1<f64>, Array1<f64>) {

fn bench_scalar_data_1d_query(c: &mut Criterion) {
let (interp, query_x, query_y) = black_box(setup());
let query_x = query_x.into_shape((2500, 4)).unwrap();
let query_y = query_y.into_shape((2500, 4)).unwrap();
let query_x = query_x.into_shape_with_order((2500, 4)).unwrap();
let query_y = query_y.into_shape_with_order((2500, 4)).unwrap();
let query_arrx: Vec<_> = query_x.axis_iter(Axis(0)).collect();
let query_arry: Vec<_> = query_y.axis_iter(Axis(0)).collect();

Expand All @@ -43,8 +43,8 @@ fn bench_scalar_data_1d_query(c: &mut Criterion) {

fn bench_scalar_data_2d_query(c: &mut Criterion) {
let (interp, query_x, query_y) = black_box(setup());
let query_x = query_x.into_shape((625, 4, 4)).unwrap();
let query_y = query_y.into_shape((625, 4, 4)).unwrap();
let query_x = query_x.into_shape_with_order((625, 4, 4)).unwrap();
let query_y = query_y.into_shape_with_order((625, 4, 4)).unwrap();
let query_arrx: Vec<_> = query_x.axis_iter(Axis(0)).collect();
let query_arry: Vec<_> = query_y.axis_iter(Axis(0)).collect();

Expand All @@ -68,8 +68,8 @@ fn bench_scalar_data_2d_query(c: &mut Criterion) {

fn bench_scalar_data_3d_query(c: &mut Criterion) {
let (interp, query_x, query_y) = black_box(setup());
let query_x = query_x.into_shape((125, 5, 4, 4)).unwrap();
let query_y = query_y.into_shape((125, 5, 4, 4)).unwrap();
let query_x = query_x.into_shape_with_order((125, 5, 4, 4)).unwrap();
let query_y = query_y.into_shape_with_order((125, 5, 4, 4)).unwrap();
let query_arrx: Vec<_> = query_x.axis_iter(Axis(0)).collect();
let query_arry: Vec<_> = query_y.axis_iter(Axis(0)).collect();

Expand Down
33 changes: 19 additions & 14 deletions src/interp1d/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,14 +308,15 @@ where
}
});

let subview = match subview.into_shape(self.data.raw_dim().remove_axis(Axis(0))) {
Ok(view) => view,
Err(err) => {
let expect = self.get_buffer_shape(xs.raw_dim()).into_pattern();
let got = buffer.dim();
panic!("{err} expected: {expect:?}, got: {got:?}")
}
};
let subview =
match subview.into_shape_with_order(self.data.raw_dim().remove_axis(Axis(0))) {
Ok(view) => view,
Err(err) => {
let expect = self.get_buffer_shape(xs.raw_dim()).into_pattern();
let got = buffer.dim();
panic!("{err} expected: {expect:?}, got: {got:?}")
}
};

self.strategy.interp_into(self, subview, x)?;
}
Expand Down Expand Up @@ -359,7 +360,11 @@ where
/// - `x` is stricktly monotonic rising
/// - `data.shape()[0] == x.len()`
/// - the `strategy` is porperly initialized with the data
pub unsafe fn new_unchecked(x: ArrayBase<Sx, Ix1>, data: ArrayBase<Sd, D>, strategy: Strat) -> Self {
pub unsafe fn new_unchecked(
x: ArrayBase<Sx, Ix1>,
data: ArrayBase<Sd, D>,
strategy: Strat,
) -> Self {
Interp1D { x, data, strategy }
}

Expand Down Expand Up @@ -498,7 +503,7 @@ mod tests {
macro_rules! get_interp {
($dim:expr, $shape:expr) => {{
let arr = rand_arr(4usize.pow($dim), (0.0, 1.0), 64)
.into_shape($shape)
.into_shape_with_order($shape)
.unwrap();
Interp1D::builder(arr).build().unwrap()
}};
Expand Down Expand Up @@ -565,7 +570,7 @@ mod tests {
#[should_panic(expected = "expected: [2], got: [1]")] // this is not really a good message
fn interp1d_2d_array_into_too_small1() {
let arr = rand_arr((4usize).pow(2), (0.0, 1.0), 64)
.into_shape((4, 4))
.into_shape_with_order((4, 4))
.unwrap();
let interp = Interp1D::builder(arr).build().unwrap();
let mut buf = Array::zeros((1, 4));
Expand All @@ -576,7 +581,7 @@ mod tests {
#[should_panic]
fn interp1d_2d_array_into_too_small2() {
let arr = rand_arr((4usize).pow(2), (0.0, 1.0), 64)
.into_shape((4, 4))
.into_shape_with_order((4, 4))
.unwrap();
let interp = Interp1D::builder(arr).build().unwrap();
let mut buf = Array::zeros((2, 3));
Expand All @@ -587,7 +592,7 @@ mod tests {
#[should_panic]
fn interp1d_2d_array_into_too_big1() {
let arr = rand_arr((4usize).pow(2), (0.0, 1.0), 64)
.into_shape((4, 4))
.into_shape_with_order((4, 4))
.unwrap();
let interp = Interp1D::builder(arr).build().unwrap();
let mut buf = Array::zeros((3, 4));
Expand All @@ -598,7 +603,7 @@ mod tests {
#[should_panic]
fn interp1d_2d_array_into_too_big2() {
let arr = rand_arr((4usize).pow(2), (0.0, 1.0), 64)
.into_shape((4, 4))
.into_shape_with_order((4, 4))
.unwrap();
let interp = Interp1D::builder(arr).build().unwrap();
let mut buf = Array::zeros((2, 5));
Expand Down
6 changes: 3 additions & 3 deletions src/interp2d/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ where
}
});

let subview = match subview.into_shape(
let subview = match subview.into_shape_with_order(
self.data
.raw_dim()
.remove_axis(Axis(0))
Expand Down Expand Up @@ -543,7 +543,7 @@ mod tests {
#[test]
fn $name() {
let arr = rand_arr(4usize.pow($dim), (0.0, 1.0), 64)
.into_shape($shape)
.into_shape_with_order($shape)
.unwrap();
let interp = Interp2D::builder(arr).build().unwrap();
let res = interp.interp(2.2, 2.2).unwrap();
Expand Down Expand Up @@ -578,7 +578,7 @@ mod tests {
#[test]
fn interp2d_2d_scalar() {
let arr = rand_arr(4usize.pow(2), (0.0, 1.0), 64)
.into_shape((4, 4))
.into_shape_with_order((4, 4))
.unwrap();
let _res: f64 = Interp2D::builder(arr) // typecheck f64 as return type
.build()
Expand Down
14 changes: 9 additions & 5 deletions tests/interp2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,19 @@ fn extrapolate() {

#[test]
fn interpolate_array() {
let data = Array::linspace(0.0, 8.0, 9).into_shape((3, 3)).unwrap();
let data = Array::linspace(0.0, 8.0, 9)
.into_shape_with_order((3, 3))
.unwrap();
let x = array![1.0, 2.0, 3.0];
let y = array![4.0, 5.0, 6.0];
let resolution = 11usize;
let qx = Array::linspace(1.0, 3.0, resolution);
let qy = Array::linspace(4.0, 6.0, resolution);
let qx = Array::from_iter(qx.into_iter().flat_map(|x| repeat(x).take(resolution)))
.into_shape((resolution, resolution))
.into_shape_with_order((resolution, resolution))
.unwrap();
let qy = Array::from_iter(repeat(qy).take(resolution).flatten())
.into_shape((resolution, resolution))
.into_shape_with_order((resolution, resolution))
.unwrap();

let interp = Interp2D::builder(data).x(x).y(y).build().unwrap();
Expand Down Expand Up @@ -247,7 +249,7 @@ fn interp_nd_data() {
.flatten()
.flatten(),
)
.into_shape((2, 2, 2, 2))
.into_shape_with_order((2, 2, 2, 2))
.unwrap();

let interp = Interp2DBuilder::new(data).build().unwrap();
Expand All @@ -265,7 +267,9 @@ fn interp_nd_data() {
#[test]
#[should_panic(expected = "`xs.shape()` and `ys.shape()` do not match")]
fn interp_array_with_unmatched_axis() {
let data = Array::linspace(0.0, 8.0, 9).into_shape((3, 3)).unwrap();
let data = Array::linspace(0.0, 8.0, 9)
.into_shape_with_order((3, 3))
.unwrap();
let qx = array![0.0, 1.0];
let qy = array![0.0, 1.0, 2.0];
let interp = Interp2D::builder(data).build().unwrap();
Expand Down
Loading