docs/OperatorKernels.md (4 changes: 3 additions & 1 deletion)
@@ -449,7 +449,9 @@ Do not modify directly.*
|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Affine|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|And|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|7+|**T** = tensor(bool)<br/> **T1** = tensor(bool)|
|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|11+|**T** = tensor(double), tensor(float), tensor(float16)|
|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||12|**T** = tensor(double), tensor(float), tensor(float16)|
|||11|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|11+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
onnxruntime/core/providers/cuda/cuda_execution_provider.cc (53 changes: 47 additions & 6 deletions)
@@ -743,9 +743,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, TopK);

// opset 11
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, ArgMin);
@@ -898,6 +898,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax);

// OpSet 13
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow);
@@ -1112,6 +1115,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax);

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add);
@@ -1593,9 +1599,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, uint8_t, DequantizeLinear)>,

// opset 11
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, double, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, ArgMin)>,
@@ -1744,6 +1750,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax)>,

// OpSet 13
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow)>,
@@ -1958,6 +1967,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, SpaceToDepth)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, DepthToSpace)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax)>,

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add)>,
@@ -2139,6 +2151,32 @@ static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node,
return false;
}

static bool ArgMaxNeedFallbackToCPU(const onnxruntime::Node& node) {
// Opset 12 introduced the attribute "select_last_index"
if (node.SinceVersion() >= 12) {
const auto& node_attributes = node.GetAttributes();

for (auto& attr : node_attributes) {
auto& attr_name = attr.first;
auto& attr_value = attr.second;

// cuDNN doesn't support picking the last index when duplicate max values
// are encountered.
// cuDNN's API doc doesn't mention what happens when duplicates are encountered,
// but based on testing, the results seem to indicate a "stable" implementation
> **Review comment (Contributor):** Let's verify this with our Nvidia POC.
// (i.e., relative ordering is preserved), which is the expected behavior when the
// attribute takes on the default value (the most common use-case for this operator).
if ("select_last_index" == attr_name) {
if (attr_value.i() != 0) {
return true;
}
}
}
}

return false;
}
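For context, ONNX introduced the `select_last_index` attribute in ArgMax-12: when the maximum value occurs more than once, it selects whether the first (0, the default) or the last (1) occurrence is reported. A minimal sketch of those semantics, using a hypothetical `ArgMaxFlat` helper over a non-empty flat array (illustrative only, not ORT's implementation):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of ONNX ArgMax tie-breaking over a non-empty flat array.
// select_last_index == false: first occurrence of the max wins (default).
// select_last_index == true:  last occurrence of the max wins.
int64_t ArgMaxFlat(const std::vector<float>& data, bool select_last_index) {
  int64_t best = 0;
  for (int64_t i = 1; i < static_cast<int64_t>(data.size()); ++i) {
    const bool better = select_last_index ? data[i] >= data[best]
                                          : data[i] > data[best];
    if (better) best = i;
  }
  return best;
}

int main() {
  std::vector<float> v{1.0f, 3.0f, 3.0f, 2.0f};
  assert(ArgMaxFlat(v, /*select_last_index=*/false) == 1);  // first max
  assert(ArgMaxFlat(v, /*select_last_index=*/true) == 2);   // last max
  return 0;
}
```

cuDNN's reduce-with-indices API exposes no equivalent switch, which is why the helper above leaves `select_last_index == 1` nodes to the CPU EP.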

static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node) {
const auto& node_attributes = node.GetAttributes();
// Check attributes
@@ -2259,6 +2297,9 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
} else if ("ConvTranspose" == node.OpType()) {
not_supported = ConvTransposeNeedFallbackToCPU(node);
force_inside = !not_supported;
} else if ("ArgMax" == node.OpType()) {
not_supported = ArgMaxNeedFallbackToCPU(node);
force_inside = !not_supported;
} else if ("Cast" == node.OpType()) {
not_supported = CastNeedFallbackToCPU(node);
// cast is not compute heavy, and may be placed outside
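The new `else if` branch added to `GetCapability` (shown above) follows the provider's existing per-op fallback pattern: an attribute-inspecting helper computes `not_supported`, and nodes that pass the check are force-placed on CUDA. A condensed sketch of that dispatch shape, with a hypothetical `FakeNode` standing in for `onnxruntime::Node`:

```cpp
#include <cstdint>
#include <map>
#include <string>

// Hypothetical stand-in for the attribute lookup the real helper performs
// on onnxruntime::Node.
struct FakeNode {
  std::string op_type;
  std::map<std::string, int64_t> int_attrs;
};

// Mirrors the shape of ArgMaxNeedFallbackToCPU: fall back only when the
// opset-12+ attribute select_last_index is present and nonzero.
bool ArgMaxNeedsCpu(const FakeNode& node) {
  auto it = node.int_attrs.find("select_last_index");
  return it != node.int_attrs.end() && it->second != 0;
}

// Simplified placement decision; the real GetCapability additionally sets
// force_inside = !not_supported to pin supported nodes to the CUDA EP.
bool PlaceOnCuda(const FakeNode& node) {
  bool not_supported = false;
  if (node.op_type == "ArgMax") {
    not_supported = ArgMaxNeedsCpu(node);
  }
  return !not_supported;
}
```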
onnxruntime/core/providers/cuda/reduction/reduction_ops.cc (79 changes: 59 additions & 20 deletions)
@@ -103,7 +103,6 @@ namespace cuda {
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);

// CUDA ArgMax/ArgMin doesn't have OpSet12 implementation (with select_last_index attr), keep it in OpSet11 for now.
#define REGISTER_KERNEL_TYPED_11(name, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
@@ -122,6 +121,40 @@ namespace cuda {
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);

#define REGISTER_ARGMAX_KERNEL_TYPED_13(T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
ArgMax, \
kOnnxDomain, \
1, 10, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
ArgMax<T>); \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
ArgMax, \
kOnnxDomain, \
11, 11, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
ArgMax<T>); \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
ArgMax, \
kOnnxDomain, \
12, 12, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
ArgMax<T>); \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
ArgMax, \
kOnnxDomain, \
13, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
ArgMax<T>);
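The four registrations in `REGISTER_ARGMAX_KERNEL_TYPED_13` partition the opset axis into [1, 10], [11, 11], [12, 12], and 13+ with no gaps or overlap, so an ArgMax node of any since-version resolves to exactly one kernel. A tiny illustrative check of that partitioning (hypothetical `ResolveArgMaxRegistration`, not ORT code):

```cpp
#include <cstdio>
#include <initializer_list>

// Maps an op's opset "since version" to the registration bucket it would
// hit, mirroring the ranges declared in REGISTER_ARGMAX_KERNEL_TYPED_13.
const char* ResolveArgMaxRegistration(int since_version) {
  if (since_version <= 10) return "[1, 10]";
  if (since_version == 11) return "[11, 11]";
  if (since_version == 12) return "[12, 12]";
  return "13+";
}

int main() {
  for (int v : {1, 10, 11, 12, 13, 14}) {
    std::printf("since-version %d -> %s\n", v, ResolveArgMaxRegistration(v));
  }
  return 0;
}
```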

// Register with the latest version 13
#define REGISTER_KERNEL_TYPED_13(name, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
@@ -807,24 +840,24 @@ Status ReduceKernel<allow_multi_axes>::ComputeImpl(OpKernelContext* ctx, cudnnRe
cudnnDataType_t cudnn_type_X = CUDNN_DATA_FLOAT; \
IAllocatorUniquePtr<float> temp_X = GetScratchBuffer<float>(input_count); \
Impl_Cast<CudaT, float>(Stream(), reinterpret_cast<const CudaT*>(X->template Data<T>()), temp_X.get(), X->Shape().Size()); \
\
ORT_RETURN_IF_ERROR(reduce_desc.Set(cudnn_reduce_op, cudnn_type_X, CUDNN_REDUCE_TENSOR_NO_INDICES)); \
ORT_RETURN_IF_ERROR(input_tensor.Set(input_dims_cudnn, cudnn_type_X)); \
ORT_RETURN_IF_ERROR(output_tensor.Set(output_dims_cudnn, cudnn_type_X)); \
CUDNN_RETURN_IF_ERROR( \
cudnnGetReductionIndicesSize(CudnnHandle(), reduce_desc, input_tensor, output_tensor, &indices_bytes)); \
CUDNN_RETURN_IF_ERROR( \
cudnnGetReductionWorkspaceSize(CudnnHandle(), reduce_desc, input_tensor, output_tensor, &workspace_bytes)); \
IAllocatorUniquePtr<uint32_t> indices_cuda = GetScratchBuffer<uint32_t>(indices_bytes); \
IAllocatorUniquePtr<CudaT> workspace_cuda = GetScratchBuffer<CudaT>(workspace_bytes); \
\
const auto one = Consts<float>::One; \
const auto zero = Consts<float>::Zero; \
auto temp_Y = GetScratchBuffer<float>(output_count); \
CUDNN_RETURN_IF_ERROR(cudnnReduceTensor(CudnnHandle(), reduce_desc, indices_cuda.get(), indices_bytes, \
workspace_cuda.get(), workspace_bytes, &one, input_tensor, temp_X.get(), \
&zero, output_tensor, temp_Y.get())); \
\
Impl_Cast<float, CudaT>(Stream(), temp_Y.get(), reinterpret_cast<CudaT*>(Y->template MutableData<T>()), output_count); \
\
return Status::OK(); \
@@ -1014,8 +1047,14 @@ template std::unique_ptr<Tensor> ReduceCompute<MLFloat16, CUDNN_REDUCE_TENSOR_NO
REGISTER_KERNEL_TYPED_11(name, float) \
REGISTER_KERNEL_TYPED_11(name, double)

REGISTER_KERNEL_HFD_11(ArgMax)
REGISTER_KERNEL_HFD_11(ArgMin)

// If support for select_last_index == 1 is ever added, remove the fallback
// logic in ArgMaxNeedFallbackToCPU() in cuda_execution_provider.cc.
REGISTER_ARGMAX_KERNEL_TYPED_13(MLFloat16)
REGISTER_ARGMAX_KERNEL_TYPED_13(float)
REGISTER_ARGMAX_KERNEL_TYPED_13(double)

REGISTER_KERNEL_HFD(ReduceL1)
REGISTER_KERNEL_HFD(ReduceL2)

onnxruntime/core/providers/cuda/reduction/reduction_ops.h (10 changes: 9 additions & 1 deletion)
@@ -88,7 +88,15 @@ class ReduceKernel : public CudaKernel, public ReduceKernelBase<allow_multi_axes
template <typename T>
class ArgMax final : public ReduceKernel<false> {
public:
ArgMax(const OpKernelInfo& info) : ReduceKernel<false>(info) {}
ArgMax(const OpKernelInfo& info) : ReduceKernel<false>(info) {
// The following is just a safety check.
// The logic in ArgMaxNeedFallbackToCPU() makes sure to not assign ArgMax
// nodes with select_last_index == 1 to the CUDA EP.
int64_t select_last_index = 0;
if (info.GetAttr<int64_t>("select_last_index", &select_last_index).IsOK()) {
ORT_ENFORCE(select_last_index == 0, "select_last_index as 1 is not supported on CUDA");
}
}

Status ComputeInternal(OpKernelContext* ctx) const override {
return ComputeImpl<T, CUDNN_REDUCE_TENSOR_FLATTENED_INDICES>(ctx, CUDNN_REDUCE_TENSOR_MAX);