Skip to content

Commit 1e75260

Browse files
committed
Ensure the relevant interfaces are implemented on the TensorSpan and ReadOnlyTensorSpan types
1 parent da9d4b0 commit 1e75260

File tree

11 files changed

+450
-666
lines changed

11 files changed

+450
-666
lines changed

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/IReadOnlyTensor.cs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,5 @@ public interface IReadOnlyTensor
5656
/// <summary>Gets the stride of each dimension in the tensor.</summary>
5757
[UnscopedRef]
5858
ReadOnlySpan<nint> Strides { get; }
59-
60-
/// <summary>Pins and gets a <see cref="MemoryHandle"/> to the backing memory.</summary>
61-
/// <returns><see cref="MemoryHandle"/></returns>
62-
MemoryHandle GetPinnedHandle();
6359
}
6460
}

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/IReadOnlyTensor_1.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,18 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System.Buffers;
5-
using System.Collections.Generic;
6-
using System.Diagnostics.CodeAnalysis;
75

86
namespace System.Numerics.Tensors
97
{
108
/// <summary>Represents a read-only tensor.</summary>
119
/// <typeparam name="TSelf">The type that implements this interface.</typeparam>
1210
/// <typeparam name="T">The element type.</typeparam>
13-
public interface IReadOnlyTensor<TSelf, T> : IReadOnlyTensor, IEnumerable<T>
11+
public interface IReadOnlyTensor<TSelf, T> : IReadOnlyTensor
12+
#if NET9_0_OR_GREATER
13+
where TSelf : IReadOnlyTensor<TSelf, T>, allows ref struct
14+
#else
1415
where TSelf : IReadOnlyTensor<TSelf, T>
16+
#endif
1517
{
1618
/// <summary>Gets an empty tensor.</summary>
1719
static abstract TSelf Empty { get; }

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/ITensor_1.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System.Buffers;
5-
using System.Diagnostics.CodeAnalysis;
65

76
namespace System.Numerics.Tensors
87
{
98
/// <summary>Represents a tensor.</summary>
109
/// <typeparam name="TSelf">The type that implements this interface.</typeparam>
1110
/// <typeparam name="T">The element type.</typeparam>
1211
public interface ITensor<TSelf, T> : ITensor, IReadOnlyTensor<TSelf, T>
12+
#if NET9_0_OR_GREATER
13+
where TSelf : ITensor<TSelf, T>, allows ref struct
14+
#else
1315
where TSelf : ITensor<TSelf, T>
16+
#endif
1417
{
1518
// TODO: Determine if we can implement `IEqualityOperators<TSelf, T, bool>`.
1619
// It looks like C#/.NET currently hits limitations here as it believes TSelf and T could be the same type

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/ReadOnlyTensorDimensionSpan_1.cs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System.Runtime.CompilerServices;
5+
6+
#if NET9_0_OR_GREATER
47
using System.Collections;
58
using System.Collections.Generic;
6-
using System.Runtime.CompilerServices;
7-
using System.Runtime.InteropServices;
9+
#endif
810

911
namespace System.Numerics.Tensors
1012
{
@@ -28,9 +30,12 @@ internal ReadOnlyTensorDimensionSpan(ReadOnlyTensorSpan<T> tensor, int dimension
2830
_tensor = tensor;
2931
_length = TensorPrimitives.Product(tensor.Lengths[..dimension]);
3032
_dimension = dimension;
31-
_sliceShape = TensorShape.Create((dimension != tensor.Rank) ? tensor.Lengths[dimension..] : [1], tensor.Strides[dimension..]);
33+
_sliceShape = TensorShape.Create((dimension != tensor.Rank) ? tensor.Lengths[dimension..] : [1], tensor.Strides[dimension..], tensor.IsPinned);
3234
}
3335

36+
/// <summary>Gets <c>true</c> if the slices that exist within the tracked dimension are dense; otherwise, <c>false</c>.</summary>
37+
public bool IsDense => _sliceShape.IsDense;
38+
3439
/// <summary>Gets the length of the tensor dimension span.</summary>
3540
public nint Length => _length;
3641

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/ReadOnlyTensorSpan_1.cs

Lines changed: 51 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ namespace System.Numerics.Tensors
2222
[DebuggerTypeProxy(typeof(TensorSpanDebugView<>))]
2323
[DebuggerDisplay("{ToString(),raw}")]
2424
public readonly ref struct ReadOnlyTensorSpan<T>
25+
#if NET9_0_OR_GREATER
26+
: IReadOnlyTensor<ReadOnlyTensorSpan<T>, T>
27+
#endif
2528
{
2629
/// <inheritdoc cref="IReadOnlyTensor{TSelf, T}.Empty" />
2730
public static ReadOnlyTensorSpan<T> Empty => default;
@@ -54,12 +57,7 @@ public ReadOnlyTensorSpan(T[]? array)
5457
/// * <paramref name="lengths" /> is not empty and has a flattened length greater than <paramref name="array" />.Length.
5558
/// </exception>
5659
public ReadOnlyTensorSpan(T[]? array, scoped ReadOnlySpan<nint> lengths)
57-
{
58-
_shape = TensorShape.Create(array, lengths);
59-
_reference = ref (array is not null)
60-
? ref MemoryMarshal.GetArrayDataReference(array)
61-
: ref Unsafe.NullRef<T>();
62-
}
60+
: this(array, lengths, strides: []) { }
6361

6462
/// <summary>Creates a new tensor over the portion of the target array beginning at the specified start index and using the specified lengths and strides.</summary>
6563
/// <param name="array">The target array.</param>
@@ -113,7 +111,7 @@ public ReadOnlyTensorSpan(T[]? array, int start, scoped ReadOnlySpan<nint> lengt
113111
public ReadOnlyTensorSpan(ReadOnlySpan<T> span)
114112
{
115113
ref T reference = ref MemoryMarshal.GetReference(span);
116-
_shape = TensorShape.Create(ref reference, span.Length);
114+
_shape = TensorShape.Create(ref reference, span.Length, pinned: false);
117115
_reference = ref reference;
118116
}
119117

@@ -126,11 +124,7 @@ public ReadOnlyTensorSpan(ReadOnlySpan<T> span)
126124
/// * <paramref name="lengths" /> is not empty and has a flattened length greater than <paramref name="span" />.Length.
127125
/// </exception>
128126
public ReadOnlyTensorSpan(ReadOnlySpan<T> span, scoped ReadOnlySpan<nint> lengths)
129-
{
130-
ref T reference = ref MemoryMarshal.GetReference(span);
131-
_shape = TensorShape.Create(ref reference, span.Length, lengths);
132-
_reference = ref reference;
133-
}
127+
: this(span, lengths, strides: []) { }
134128

135129
/// <summary>Creates a new tensor span over the target span using the specified lengths and strides.</summary>
136130
/// <param name="span">The target span.</param>
@@ -147,7 +141,7 @@ public ReadOnlyTensorSpan(ReadOnlySpan<T> span, scoped ReadOnlySpan<nint> length
147141
public ReadOnlyTensorSpan(ReadOnlySpan<T> span, scoped ReadOnlySpan<nint> lengths, scoped ReadOnlySpan<nint> strides)
148142
{
149143
ref T reference = ref MemoryMarshal.GetReference(span);
150-
_shape = TensorShape.Create(ref reference, span.Length, lengths, strides);
144+
_shape = TensorShape.Create(ref reference, span.Length, lengths, strides, pinned: false);
151145
_reference = ref reference;
152146
}
153147

@@ -219,10 +213,7 @@ public unsafe ReadOnlyTensorSpan(T* data, nint dataLength)
219213
/// </exception>
220214
[CLSCompliant(false)]
221215
public unsafe ReadOnlyTensorSpan(T* data, nint dataLength, scoped ReadOnlySpan<nint> lengths)
222-
{
223-
_shape = TensorShape.Create(data, dataLength, lengths);
224-
_reference = ref Unsafe.AsRef<T>(data);
225-
}
216+
: this(data, dataLength, lengths, strides: []) { }
226217

227218
/// <summary>Creates a new tensor span over the target unmanaged buffer using the specified lengths and strides.</summary>
228219
/// <param name="data">The pointer to the start of the target unmanaged buffer.</param>
@@ -247,28 +238,9 @@ public unsafe ReadOnlyTensorSpan(T* data, nint dataLength, scoped ReadOnlySpan<n
247238
_reference = ref Unsafe.AsRef<T>(data);
248239
}
249240

250-
// Constructor for internal use only. It is not safe to expose publicly.
251-
internal ReadOnlyTensorSpan(ref T data, nint dataLength)
252-
{
253-
_shape = TensorShape.Create(ref data, dataLength);
254-
_reference = ref data;
255-
}
256-
257-
internal ReadOnlyTensorSpan(ref T data, nint dataLength, scoped ReadOnlySpan<nint> lengths)
258-
{
259-
_shape = TensorShape.Create(ref data, dataLength, lengths);
260-
_reference = ref data;
261-
}
262-
263-
internal ReadOnlyTensorSpan(ref T data, nint dataLength, scoped ReadOnlySpan<nint> lengths, scoped ReadOnlySpan<nint> strides)
241+
internal ReadOnlyTensorSpan(ref T data, nint dataLength, scoped ReadOnlySpan<nint> lengths, scoped ReadOnlySpan<nint> strides, bool pinned)
264242
{
265-
_shape = TensorShape.Create(ref data, dataLength, lengths, strides);
266-
_reference = ref data;
267-
}
268-
269-
internal ReadOnlyTensorSpan(ref T data, nint dataLength, scoped ReadOnlySpan<nint> lengths, scoped ReadOnlySpan<nint> strides, scoped ReadOnlySpan<int> linearRankOrder)
270-
{
271-
_shape = TensorShape.Create(ref data, dataLength, lengths, strides, linearRankOrder);
243+
_shape = TensorShape.Create(ref data, dataLength, lengths, strides, pinned);
272244
_reference = ref data;
273245
}
274246

@@ -308,6 +280,9 @@ public ReadOnlyTensorSpan<T> this[params scoped ReadOnlySpan<NRange> ranges]
308280
/// <inheritdoc cref="IReadOnlyTensor.IsEmpty" />
309281
public bool IsEmpty => _shape.IsEmpty;
310282

283+
/// <inheritdoc cref="IReadOnlyTensor.IsPinned" />
284+
public bool IsPinned => _shape.IsPinned;
285+
311286
/// <inheritdoc cref="IReadOnlyTensor.Lengths" />
312287
[UnscopedRef]
313288
public ReadOnlySpan<nint> Lengths => _shape.Lengths;
@@ -463,11 +438,47 @@ public bool TryFlattenTo(scoped Span<T> destination)
463438
return false;
464439
}
465440

441+
#if NET9_0_OR_GREATER
442+
//
443+
// IReadOnlyTensor
444+
//
445+
446+
object? IReadOnlyTensor.this[params scoped ReadOnlySpan<NIndex> indexes] => this[indexes];
447+
448+
object? IReadOnlyTensor.this[params scoped ReadOnlySpan<nint> indexes] => this[indexes];
449+
450+
//
451+
// IReadOnlyTensor<TSelf, T>
452+
//
453+
454+
ReadOnlyTensorSpan<T> IReadOnlyTensor<ReadOnlyTensorSpan<T>, T>.AsReadOnlyTensorSpan() => this;
455+
456+
ReadOnlyTensorSpan<T> IReadOnlyTensor<ReadOnlyTensorSpan<T>, T>.AsReadOnlyTensorSpan(params scoped ReadOnlySpan<nint> startIndexes) => Slice(startIndexes);
457+
458+
ReadOnlyTensorSpan<T> IReadOnlyTensor<ReadOnlyTensorSpan<T>, T>.AsReadOnlyTensorSpan(params scoped ReadOnlySpan<NIndex> startIndexes) => Slice(startIndexes);
459+
460+
ReadOnlyTensorSpan<T> IReadOnlyTensor<ReadOnlyTensorSpan<T>, T>.AsReadOnlyTensorSpan(params scoped ReadOnlySpan<NRange> ranges) => Slice(ranges);
461+
462+
ReadOnlyTensorSpan<T> IReadOnlyTensor<ReadOnlyTensorSpan<T>, T>.ToDenseTensor()
463+
{
464+
ReadOnlyTensorSpan<T> result = this;
465+
466+
if (!IsDense)
467+
{
468+
Tensor<T> tmp = Tensor.Create<T>(Lengths, IsPinned);
469+
CopyTo(tmp);
470+
result = tmp;
471+
}
472+
473+
return result;
474+
}
475+
#endif
476+
466477
/// <summary>Enumerates the elements of a tensor span.</summary>
467478
public ref struct Enumerator : IEnumerator<T>
468479
{
469480
private readonly ReadOnlyTensorSpan<T> _span;
470-
private nint[] _indexes;
481+
private readonly nint[] _indexes;
471482
private nint _linearOffset;
472483
private nint _itemsEnumerated;
473484

@@ -513,7 +524,7 @@ public void Reset()
513524
// IDisposable
514525
//
515526

516-
void IDisposable.Dispose() { }
527+
readonly void IDisposable.Dispose() { }
517528

518529
//
519530
// IEnumerator

0 commit comments

Comments
 (0)