Skip to content

Commit a3b3537

Browse files
authored
Makes the reference types sound via a DST (#1532)
The current implementation of ArrayRef and its cousins has them as sized types, which turns out to be a critical and unsound mistake. This PR is large, but its heart is small: change the ArrayRef implementation to be unsized. The approach this PR takes is to package the core array "metadata" - the pointer, dim, and strides - into a struct that can either be sized or unsized. This is done by appending a generic "_dst_control" field. For the "sized" version of the metadata, that field is a 0-length array. For the "unsized" version of the metadata, that sized field is a struct. This core type is private, so users cannot construct any other variants other than these two. We then put the sized version into the ArrayBase types, put the unsized version into the reference types, and perform an unsizing coercion to convert from one to the other. Because Rust has no (safe, supported) "resizing" coercion, this switch is irreversible. Sized types cannot be recovered from the unsized reference types.
1 parent e6bf804 commit a3b3537

29 files changed

+457
-334
lines changed

examples/functions_and_traits.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ fn takes_rawref_mut<A, D>(arr: &mut RawRef<A, D>)
139139
/// Immutable, take a generic that implements `AsRef` to `RawRef`
140140
#[allow(dead_code)]
141141
fn takes_rawref_asref<T, A, D>(_arr: &T)
142-
where T: AsRef<RawRef<A, D>>
142+
where T: AsRef<RawRef<A, D>> + ?Sized
143143
{
144144
takes_layout(_arr.as_ref());
145145
takes_layout_asref(_arr.as_ref());
@@ -148,7 +148,7 @@ where T: AsRef<RawRef<A, D>>
148148
/// Mutable, take a generic that implements `AsMut` to `RawRef`
149149
#[allow(dead_code)]
150150
fn takes_rawref_asmut<T, A, D>(_arr: &mut T)
151-
where T: AsMut<RawRef<A, D>>
151+
where T: AsMut<RawRef<A, D>> + ?Sized
152152
{
153153
takes_layout_mut(_arr.as_mut());
154154
takes_layout_asmut(_arr.as_mut());
@@ -169,10 +169,16 @@ fn takes_layout_mut<A, D>(_arr: &mut LayoutRef<A, D>) {}
169169

170170
/// Immutable, take a generic that implements `AsRef` to `LayoutRef`
171171
#[allow(dead_code)]
172-
fn takes_layout_asref<T: AsRef<LayoutRef<A, D>>, A, D>(_arr: &T) {}
172+
fn takes_layout_asref<T, A, D>(_arr: &T)
173+
where T: AsRef<LayoutRef<A, D>> + ?Sized
174+
{
175+
}
173176

174177
/// Mutable, take a generic that implements `AsMut` to `LayoutRef`
175178
#[allow(dead_code)]
176-
fn takes_layout_asmut<T: AsMut<LayoutRef<A, D>>, A, D>(_arr: &mut T) {}
179+
fn takes_layout_asmut<T, A, D>(_arr: &mut T)
180+
where T: AsMut<LayoutRef<A, D>> + ?Sized
181+
{
182+
}
177183

178184
fn main() {}

scripts/all-tests.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,8 @@ fi
3030
# Examples
3131
cargo nextest run --examples
3232

33+
# Doc tests
34+
cargo test --doc
35+
3336
# Benchmarks
3437
([ "$CHANNEL" != "nightly" ] || cargo bench --no-run --verbose --features "$FEATURES")

src/arraytraits.rs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ use std::{iter::FromIterator, slice};
1919
use crate::imp_prelude::*;
2020
use crate::Arc;
2121

22-
use crate::LayoutRef;
2322
use crate::{
2423
dimension,
2524
iter::{Iter, IterMut},
@@ -38,12 +37,14 @@ pub(crate) fn array_out_of_bounds() -> !
3837
}
3938

4039
#[inline(always)]
41-
pub fn debug_bounds_check<A, D, I>(_a: &LayoutRef<A, D>, _index: &I)
40+
pub fn debug_bounds_check<A, D, I, T>(_a: &T, _index: &I)
4241
where
4342
D: Dimension,
4443
I: NdIndex<D>,
44+
T: AsRef<LayoutRef<A, D>> + ?Sized,
4545
{
46-
debug_bounds_check!(_a, *_index);
46+
let _layout_ref = _a.as_ref();
47+
debug_bounds_check_ref!(_layout_ref, *_index);
4748
}
4849

4950
/// Access the element at **index**.
@@ -59,11 +60,11 @@ where
5960
#[inline]
6061
fn index(&self, index: I) -> &Self::Output
6162
{
62-
debug_bounds_check!(self, index);
63+
debug_bounds_check_ref!(self, index);
6364
unsafe {
64-
&*self.ptr.as_ptr().offset(
65+
&*self._ptr().as_ptr().offset(
6566
index
66-
.index_checked(&self.dim, &self.strides)
67+
.index_checked(self._dim(), self._strides())
6768
.unwrap_or_else(|| array_out_of_bounds()),
6869
)
6970
}
@@ -81,11 +82,11 @@ where
8182
#[inline]
8283
fn index_mut(&mut self, index: I) -> &mut A
8384
{
84-
debug_bounds_check!(self, index);
85+
debug_bounds_check_ref!(self, index);
8586
unsafe {
8687
&mut *self.as_mut_ptr().offset(
8788
index
88-
.index_checked(&self.dim, &self.strides)
89+
.index_checked(self._dim(), self._strides())
8990
.unwrap_or_else(|| array_out_of_bounds()),
9091
)
9192
}
@@ -581,7 +582,7 @@ where D: Dimension
581582
{
582583
let data = OwnedArcRepr(Arc::new(arr.data));
583584
// safe because: equivalent unmoved data, ptr and dims remain valid
584-
unsafe { ArrayBase::from_data_ptr(data, arr.layout.ptr).with_strides_dim(arr.layout.strides, arr.layout.dim) }
585+
unsafe { ArrayBase::from_data_ptr(data, arr.parts.ptr).with_strides_dim(arr.parts.strides, arr.parts.dim) }
585586
}
586587
}
587588

src/data_traits.rs

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ where A: Clone
251251
if Arc::get_mut(&mut self_.data.0).is_some() {
252252
return;
253253
}
254-
if self_.layout.dim.size() <= self_.data.0.len() / 2 {
254+
if self_.parts.dim.size() <= self_.data.0.len() / 2 {
255255
// Clone only the visible elements if the current view is less than
256256
// half of backing data.
257257
*self_ = self_.to_owned().into_shared();
@@ -260,13 +260,13 @@ where A: Clone
260260
let rcvec = &mut self_.data.0;
261261
let a_size = mem::size_of::<A>() as isize;
262262
let our_off = if a_size != 0 {
263-
(self_.layout.ptr.as_ptr() as isize - rcvec.as_ptr() as isize) / a_size
263+
(self_.parts.ptr.as_ptr() as isize - rcvec.as_ptr() as isize) / a_size
264264
} else {
265265
0
266266
};
267267
let rvec = Arc::make_mut(rcvec);
268268
unsafe {
269-
self_.layout.ptr = rvec.as_nonnull_mut().offset(our_off);
269+
self_.parts.ptr = rvec.as_nonnull_mut().offset(our_off);
270270
}
271271
}
272272

@@ -287,7 +287,7 @@ unsafe impl<A> Data for OwnedArcRepr<A>
287287
let data = Arc::try_unwrap(self_.data.0).ok().unwrap();
288288
// safe because data is equivalent
289289
unsafe {
290-
ArrayBase::from_data_ptr(data, self_.layout.ptr).with_strides_dim(self_.layout.strides, self_.layout.dim)
290+
ArrayBase::from_data_ptr(data, self_.parts.ptr).with_strides_dim(self_.parts.strides, self_.parts.dim)
291291
}
292292
}
293293

@@ -297,14 +297,14 @@ unsafe impl<A> Data for OwnedArcRepr<A>
297297
match Arc::try_unwrap(self_.data.0) {
298298
Ok(owned_data) => unsafe {
299299
// Safe because the data is equivalent.
300-
Ok(ArrayBase::from_data_ptr(owned_data, self_.layout.ptr)
301-
.with_strides_dim(self_.layout.strides, self_.layout.dim))
300+
Ok(ArrayBase::from_data_ptr(owned_data, self_.parts.ptr)
301+
.with_strides_dim(self_.parts.strides, self_.parts.dim))
302302
},
303303
Err(arc_data) => unsafe {
304304
// Safe because the data is equivalent; we're just
305305
// reconstructing `self_`.
306-
Err(ArrayBase::from_data_ptr(OwnedArcRepr(arc_data), self_.layout.ptr)
307-
.with_strides_dim(self_.layout.strides, self_.layout.dim))
306+
Err(ArrayBase::from_data_ptr(OwnedArcRepr(arc_data), self_.parts.ptr)
307+
.with_strides_dim(self_.parts.strides, self_.parts.dim))
308308
},
309309
}
310310
}
@@ -603,9 +603,9 @@ where A: Clone
603603
CowRepr::View(_) => {
604604
let owned = ArrayRef::to_owned(array);
605605
array.data = CowRepr::Owned(owned.data);
606-
array.layout.ptr = owned.layout.ptr;
607-
array.layout.dim = owned.layout.dim;
608-
array.layout.strides = owned.layout.strides;
606+
array.parts.ptr = owned.parts.ptr;
607+
array.parts.dim = owned.parts.dim;
608+
array.parts.strides = owned.parts.strides;
609609
}
610610
CowRepr::Owned(_) => {}
611611
}
@@ -666,8 +666,7 @@ unsafe impl<'a, A> Data for CowRepr<'a, A>
666666
CowRepr::View(_) => self_.to_owned(),
667667
CowRepr::Owned(data) => unsafe {
668668
// safe because the data is equivalent so ptr, dims remain valid
669-
ArrayBase::from_data_ptr(data, self_.layout.ptr)
670-
.with_strides_dim(self_.layout.strides, self_.layout.dim)
669+
ArrayBase::from_data_ptr(data, self_.parts.ptr).with_strides_dim(self_.parts.strides, self_.parts.dim)
671670
},
672671
}
673672
}
@@ -679,8 +678,8 @@ unsafe impl<'a, A> Data for CowRepr<'a, A>
679678
CowRepr::View(_) => Err(self_),
680679
CowRepr::Owned(data) => unsafe {
681680
// safe because the data is equivalent so ptr, dims remain valid
682-
Ok(ArrayBase::from_data_ptr(data, self_.layout.ptr)
683-
.with_strides_dim(self_.layout.strides, self_.layout.dim))
681+
Ok(ArrayBase::from_data_ptr(data, self_.parts.ptr)
682+
.with_strides_dim(self_.parts.strides, self_.parts.dim))
684683
},
685684
}
686685
}

src/free_functions.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use std::mem::{forget, size_of};
1616
use std::ptr::NonNull;
1717

1818
use crate::{dimension, ArcArray1, ArcArray2};
19-
use crate::{imp_prelude::*, LayoutRef};
19+
use crate::{imp_prelude::*, ArrayPartsSized};
2020

2121
/// Create an **[`Array`]** with one, two, three, four, five, or six dimensions.
2222
///
@@ -109,12 +109,12 @@ pub const fn aview0<A>(x: &A) -> ArrayView0<'_, A>
109109
{
110110
ArrayBase {
111111
data: ViewRepr::new(),
112-
layout: LayoutRef {
112+
parts: ArrayPartsSized::new(
113113
// Safe because references are always non-null.
114-
ptr: unsafe { NonNull::new_unchecked(x as *const A as *mut A) },
115-
dim: Ix0(),
116-
strides: Ix0(),
117-
},
114+
unsafe { NonNull::new_unchecked(x as *const A as *mut A) },
115+
Ix0(),
116+
Ix0(),
117+
),
118118
}
119119
}
120120

@@ -149,12 +149,12 @@ pub const fn aview1<A>(xs: &[A]) -> ArrayView1<'_, A>
149149
}
150150
ArrayBase {
151151
data: ViewRepr::new(),
152-
layout: LayoutRef {
152+
parts: ArrayPartsSized::new(
153153
// Safe because references are always non-null.
154-
ptr: unsafe { NonNull::new_unchecked(xs.as_ptr() as *mut A) },
155-
dim: Ix1(xs.len()),
156-
strides: Ix1(1),
157-
},
154+
unsafe { NonNull::new_unchecked(xs.as_ptr() as *mut A) },
155+
Ix1(xs.len()),
156+
Ix1(1),
157+
),
158158
}
159159
}
160160

@@ -207,7 +207,7 @@ pub const fn aview2<A, const N: usize>(xs: &[[A; N]]) -> ArrayView2<'_, A>
207207
};
208208
ArrayBase {
209209
data: ViewRepr::new(),
210-
layout: LayoutRef { ptr, dim, strides },
210+
parts: ArrayPartsSized::new(ptr, dim, strides),
211211
}
212212
}
213213

src/impl_clone.rs

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
// except according to those terms.
88

99
use crate::imp_prelude::*;
10-
use crate::LayoutRef;
10+
use crate::ArrayPartsSized;
1111
use crate::RawDataClone;
1212

1313
impl<S: RawDataClone, D: Clone> Clone for ArrayBase<S, D>
@@ -16,14 +16,10 @@ impl<S: RawDataClone, D: Clone> Clone for ArrayBase<S, D>
1616
{
1717
// safe because `clone_with_ptr` promises to provide equivalent data and ptr
1818
unsafe {
19-
let (data, ptr) = self.data.clone_with_ptr(self.layout.ptr);
19+
let (data, ptr) = self.data.clone_with_ptr(self.parts.ptr);
2020
ArrayBase {
2121
data,
22-
layout: LayoutRef {
23-
ptr,
24-
dim: self.layout.dim.clone(),
25-
strides: self.layout.strides.clone(),
26-
},
22+
parts: ArrayPartsSized::new(ptr, self.parts.dim.clone(), self.parts.strides.clone()),
2723
}
2824
}
2925
}
@@ -34,9 +30,9 @@ impl<S: RawDataClone, D: Clone> Clone for ArrayBase<S, D>
3430
fn clone_from(&mut self, other: &Self)
3531
{
3632
unsafe {
37-
self.layout.ptr = self.data.clone_from_with_ptr(&other.data, other.layout.ptr);
38-
self.layout.dim.clone_from(&other.layout.dim);
39-
self.layout.strides.clone_from(&other.layout.strides);
33+
self.parts.ptr = self.data.clone_from_with_ptr(&other.data, other.parts.ptr);
34+
self.parts.dim.clone_from(&other.parts.dim);
35+
self.parts.strides.clone_from(&other.parts.strides);
4036
}
4137
}
4238
}

src/impl_cow.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ where D: Dimension
3434
{
3535
// safe because equivalent data
3636
unsafe {
37-
ArrayBase::from_data_ptr(CowRepr::View(view.data), view.ptr)
38-
.with_strides_dim(view.layout.strides, view.layout.dim)
37+
ArrayBase::from_data_ptr(CowRepr::View(view.data), view.parts.ptr)
38+
.with_strides_dim(view.parts.strides, view.parts.dim)
3939
}
4040
}
4141
}
@@ -47,8 +47,8 @@ where D: Dimension
4747
{
4848
// safe because equivalent data
4949
unsafe {
50-
ArrayBase::from_data_ptr(CowRepr::Owned(array.data), array.layout.ptr)
51-
.with_strides_dim(array.layout.strides, array.layout.dim)
50+
ArrayBase::from_data_ptr(CowRepr::Owned(array.data), array.parts.ptr)
51+
.with_strides_dim(array.parts.strides, array.parts.dim)
5252
}
5353
}
5454
}

src/impl_dyn.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ impl<A> LayoutRef<A, IxDyn>
3131
pub fn insert_axis_inplace(&mut self, axis: Axis)
3232
{
3333
assert!(axis.index() <= self.ndim());
34-
self.dim = self.dim.insert_axis(axis);
35-
self.strides = self.strides.insert_axis(axis);
34+
self.0.dim = self._dim().insert_axis(axis);
35+
self.0.strides = self._strides().insert_axis(axis);
3636
}
3737

3838
/// Collapses the array to `index` along the axis and removes the axis,
@@ -54,8 +54,8 @@ impl<A> LayoutRef<A, IxDyn>
5454
pub fn index_axis_inplace(&mut self, axis: Axis, index: usize)
5555
{
5656
self.collapse_axis(axis, index);
57-
self.dim = self.dim.remove_axis(axis);
58-
self.strides = self.strides.remove_axis(axis);
57+
self.0.dim = self._dim().remove_axis(axis);
58+
self.0.strides = self._strides().remove_axis(axis);
5959
}
6060
}
6161

src/impl_internal_constructors.rs

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
use std::ptr::NonNull;
1010

11-
use crate::{imp_prelude::*, LayoutRef};
11+
use crate::{imp_prelude::*, ArrayPartsSized};
1212

1313
// internal "builder-like" methods
1414
impl<A, S> ArrayBase<S, Ix1>
@@ -27,11 +27,7 @@ where S: RawData<Elem = A>
2727
{
2828
let array = ArrayBase {
2929
data,
30-
layout: LayoutRef {
31-
ptr,
32-
dim: Ix1(0),
33-
strides: Ix1(1),
34-
},
30+
parts: ArrayPartsSized::new(ptr, Ix1(0), Ix1(1)),
3531
};
3632
debug_assert!(array.pointer_is_inbounds());
3733
array
@@ -60,11 +56,7 @@ where
6056
debug_assert_eq!(strides.ndim(), dim.ndim());
6157
ArrayBase {
6258
data: self.data,
63-
layout: LayoutRef {
64-
ptr: self.layout.ptr,
65-
dim,
66-
strides,
67-
},
59+
parts: ArrayPartsSized::new(self.parts.ptr, dim, strides),
6860
}
6961
}
7062
}

0 commit comments

Comments
 (0)