[API Proposal]: Generic overloads of existing TensorPrimitives methods #94553

@stephentoub

Background and motivation

In .NET 8 we added the new TensorPrimitives type, whose methods are dedicated to handling float. After .NET 8, we plan to augment it in three ways (#93286):

  1. Generic versions of these methods, to handle numeric types beyond float
  2. Additional generic methods for all the operations on the generic math interfaces that aren't currently on TensorPrimitives
  3. Additional generic methods for good coverage of further relevant operations, à la BLAS / LAPACK

This issue covers (1).

API Proposal

The same signatures as the existing methods on TensorPrimitives in .NET 8, each with a generic overload that takes T in place of float.

namespace System.Numerics.Tensors;

public static class TensorPrimitives
{
    public static void Abs<T>(ReadOnlySpan<T> x, Span<T> destination) where T : INumberBase<T>;
    public static void AddMultiply<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, ReadOnlySpan<T> multiplier, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
    public static void AddMultiply<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, T multiplier, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
    public static void AddMultiply<T>(ReadOnlySpan<T> x, T y, ReadOnlySpan<T> multiplier, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
    public static void Add<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>;
    public static void Add<T>(ReadOnlySpan<T> x, T y, Span<T> destination) where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>;
    public static void Cosh<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IHyperbolicFunctions<T>;
    public static T CosineSimilarity<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : IRootFunctions<T>;
    public static T Distance<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : INumberBase<T>, IRootFunctions<T>;
    public static void Divide<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : IDivisionOperators<T, T, T>;
    public static void Divide<T>(ReadOnlySpan<T> x, T y, Span<T> destination) where T : IDivisionOperators<T, T, T>;
    public static T Dot<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>, IMultiplyOperators<T, T, T>, IMultiplicativeIdentity<T, T>;
    public static void Exp<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IExponentialFunctions<T>;
    public static int IndexOfMax<T>(ReadOnlySpan<T> x) where T : INumber<T>;
    public static int IndexOfMaxMagnitude<T>(ReadOnlySpan<T> x) where T : INumber<T>;
    public static int IndexOfMin<T>(ReadOnlySpan<T> x) where T : INumber<T>;
    public static int IndexOfMinMagnitude<T>(ReadOnlySpan<T> x) where T : INumber<T>;
    public static void Log2<T>(ReadOnlySpan<T> x, Span<T> destination) where T : ILogarithmicFunctions<T>;
    public static void Log<T>(ReadOnlySpan<T> x, Span<T> destination) where T : ILogarithmicFunctions<T>;
    public static T MaxMagnitude<T>(ReadOnlySpan<T> x) where T : INumberBase<T>;
    public static void MaxMagnitude<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumberBase<T>;
    public static T Max<T>(ReadOnlySpan<T> x) where T : INumber<T>;
    public static void Max<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumber<T>;
    public static T MinMagnitude<T>(ReadOnlySpan<T> x) where T : INumberBase<T>;
    public static void MinMagnitude<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumberBase<T>;
    public static T Min<T>(ReadOnlySpan<T> x) where T : INumber<T>;
    public static void Min<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumber<T>;
    public static void MultiplyAdd<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, ReadOnlySpan<T> addend, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
    public static void MultiplyAdd<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, T addend, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
    public static void MultiplyAdd<T>(ReadOnlySpan<T> x, T y, ReadOnlySpan<T> addend, Span<T> destination) where T : IAdditionOperators<T, T, T>, IMultiplyOperators<T, T, T>;
    public static void Multiply<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : INumberBase<T>;
    public static void Multiply<T>(ReadOnlySpan<T> x, T y, Span<T> destination) where T : IMultiplyOperators<T, T, T>, IMultiplicativeIdentity<T, T>;
    public static void Negate<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IUnaryNegationOperators<T, T>;
    public static T Norm<T>(ReadOnlySpan<T> x) where T : IRootFunctions<T>;
    public static T ProductOfDifferences<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : ISubtractionOperators<T, T, T>, IMultiplyOperators<T, T, T>, IMultiplicativeIdentity<T, T>;
    public static T ProductOfSums<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y) where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>, IMultiplyOperators<T, T, T>, IMultiplicativeIdentity<T, T>;
    public static T Product<T>(ReadOnlySpan<T> x) where T : INumberBase<T>;
    public static void Sigmoid<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IExponentialFunctions<T>;
    public static void Sinh<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IHyperbolicFunctions<T>;
    public static void SoftMax<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IExponentialFunctions<T>;
    public static void Subtract<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination) where T : ISubtractionOperators<T, T, T>;
    public static void Subtract<T>(ReadOnlySpan<T> x, T y, Span<T> destination) where T : ISubtractionOperators<T, T, T>;
    public static T SumOfMagnitudes<T>(ReadOnlySpan<T> x) where T : INumberBase<T>;
    public static T SumOfSquares<T>(ReadOnlySpan<T> x) where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>, IMultiplyOperators<T, T, T>;
    public static T Sum<T>(ReadOnlySpan<T> x) where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>;
    public static void Tanh<T>(ReadOnlySpan<T> x, Span<T> destination) where T : IHyperbolicFunctions<T>;
}
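As a rough check that the listed constraints are sufficient, here is a minimal scalar sketch of two of the overloads, Sum<T> and Dot<T>, written as C# top-level local functions that use exactly the constraints proposed above. It is hypothetical and illustrative only: a real TensorPrimitives implementation would be vectorized, and the exception message here is an assumption.

```csharp
using System;
using System.Numerics;

// Hypothetical scalar Sum<T>: needs only '+' and a zero to start from.
static T Sum<T>(ReadOnlySpan<T> x)
    where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>
{
    T result = T.AdditiveIdentity;
    foreach (T value in x)
    {
        result += value;
    }
    return result;
}

// Hypothetical scalar Dot<T> with the constraints from the listing.
static T Dot<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y)
    where T : IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>,
              IMultiplyOperators<T, T, T>, IMultiplicativeIdentity<T, T>
{
    if (x.Length != y.Length)
        throw new ArgumentException("Input spans must have the same length.");

    T result = T.AdditiveIdentity;
    for (int i = 0; i < x.Length; i++)
    {
        result += x[i] * y[i]; // accumulate element-wise products
    }
    return result;
}

// Explicit type arguments: T is not inferred from an array for a ReadOnlySpan<T> parameter.
Console.WriteLine(Sum<int>(new[] { 1, 2, 3 }));                   // prints 6
Console.WriteLine(Dot<double>(new[] { 1.0, 2.0 }, new[] { 3.0, 4.0 })); // prints 11
```

This Dot<T> loop never uses T.MultiplicativeIdentity, so that constraint may be one to revisit when the constraint sets are double-checked.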
  • Each method is constrained to what I believe is the smallest viable set of interfaces required (we should double-check this). @tannergooding, is this the right approach? Or should every method instead constrain T to INumber<T> or IFloatingPointIeee754<T> for simplicity, consistency, and future flexibility?
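One concrete consequence of the minimal-constraint approach: System.Numerics.Complex implements INumberBase<Complex> but not INumber<Complex>, since complex numbers have no ordering. The sketch below is a hypothetical scalar Product<T> using the INumberBase<T> constraint from the listing; it accepts Complex, whereas the same method constrained to INumber<T> (or IFloatingPointIeee754<T>) would reject it at compile time. It omits argument validation and is not the library's implementation.

```csharp
using System;
using System.Numerics;

// Hypothetical scalar Product<T> with the INumberBase<T> constraint.
static T Product<T>(ReadOnlySpan<T> x) where T : INumberBase<T>
{
    T result = T.One; // multiplicative identity supplied by INumberBase<T>
    foreach (T value in x)
    {
        result *= value;
    }
    return result;
}

// Complex satisfies INumberBase<T>, so this compiles; an INumber<T>
// constraint on Product<T> would make this call a compile-time error.
Complex[] values = { new Complex(1, 1), new Complex(1, -1) };
Complex product = Product<Complex>(values);
Console.WriteLine(product == new Complex(2, 0)); // prints True
```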

API Usage

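using System.Numerics.Tensors;

double[] values1 = ..., values2 = ...;
double similarity = TensorPrimitives.CosineSimilarity<double>(values1, values2); // explicit <double>: before C# 14, T isn't inferred from a double[] argument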
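For reference, the value this call computes is dot(x, y) / (sqrt(x·x) * sqrt(y·y)). Below is a minimal scalar sketch under the proposed IRootFunctions<T> constraint, which supplies T.Sqrt and, through INumberBase<T>, T.Zero, +, * and /. It is a hypothetical local function rather than the vectorized library code, and the sample arrays are assumptions standing in for the elided values.

```csharp
using System;
using System.Numerics;

// Hypothetical scalar CosineSimilarity<T>: dot(x, y) / (sqrt(x.x) * sqrt(y.y)).
static T CosineSimilarity<T>(ReadOnlySpan<T> x, ReadOnlySpan<T> y)
    where T : IRootFunctions<T>
{
    if (x.Length != y.Length)
        throw new ArgumentException("Input spans must have the same length.");

    T dot = T.Zero, sumSquaresX = T.Zero, sumSquaresY = T.Zero;
    for (int i = 0; i < x.Length; i++)
    {
        dot += x[i] * y[i];
        sumSquaresX += x[i] * x[i];
        sumSquaresY += y[i] * y[i];
    }
    return dot / (T.Sqrt(sumSquaresX) * T.Sqrt(sumSquaresY));
}

// Sample inputs (assumed; the proposal elides the actual values).
double[] values1 = { 1.0, 2.0, 3.0 }, values2 = { 4.0, 5.0, 6.0 };
double similarity = CosineSimilarity<double>(values1, values2);
Console.WriteLine(similarity); // 32 / (sqrt(14) * sqrt(77)), roughly 0.9746
```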

Alternative Designs

No response

Risks

No response