Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ struct ConvertBuiltin {
bool IsSaturated;
bool IsRounded;
bool IsBfloat16;
bool IsTF32;
FPRoundingMode::FPRoundingMode RoundingMode;
};

Expand Down Expand Up @@ -230,6 +231,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall,
// - "__spirv_SubgroupImageMediaBlockReadINTEL"
// - "__spirv_SubgroupImageMediaBlockWriteINTEL"
// - "__spirv_Convert"
// - "__spirv_Round"
// - "__spirv_UConvert"
// - "__spirv_SConvert"
// - "__spirv_FConvert"
Expand All @@ -242,7 +244,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall,
"SDotKHR|SUDotKHR|SDotAccSatKHR|UDotAccSatKHR|SUDotAccSatKHR|"
"ReadClockKHR|SubgroupBlockReadINTEL|SubgroupImageBlockReadINTEL|"
"SubgroupImageMediaBlockReadINTEL|SubgroupImageMediaBlockWriteINTEL|"
"Convert|"
"Convert|Round|"
"UConvert|SConvert|FConvert|SatConvert)[^_]*)(_R[^_]*_?(\\w+)?.*)?");
std::smatch Match;
if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 1) {
Expand Down Expand Up @@ -697,7 +699,8 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
if (Call->isSpirvOp())
return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));
return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call,
Register(0));

Register ScopeRegister =
buildConstantIntReg32(SPIRV::Scope::Device, MIRBuilder, GR);
Expand Down Expand Up @@ -2677,8 +2680,20 @@ static bool generateConvertInst(const StringRef DemangledCall,
}
} else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
SPIRV::OpTypeFloat)) {
// Float -> Float
Opcode = SPIRV::OpFConvert;
if (Builtin->IsTF32) {
const auto *ST = static_cast<const SPIRVSubtarget *>(
&MIRBuilder.getMF().getSubtarget());
if (!ST->canUseExtension(
SPIRV::Extension::SPV_INTEL_tensor_float32_conversion))
NeedExtMsg = "SPV_INTEL_tensor_float32_conversion";
IsRightComponentsNumber =
GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
Opcode = SPIRV::OpRoundFToTF32INTEL;
} else {
// Float -> Float
Opcode = SPIRV::OpFConvert;
}
}
}

Expand Down
23 changes: 22 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVBuiltins.td
Original file line number Diff line number Diff line change
Expand Up @@ -1461,6 +1461,8 @@ class ConvertBuiltin<string name, InstructionSet set> {
bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
bit IsBfloat16 = !or(!not(!eq(!find(name, "BF16"), -1)),
!not(!eq(!find(name, "bfloat16"), -1)));
bit IsTF32 = !or(!not(!eq(!find(name, "TF32"), -1)),
!not(!eq(!find(name, "tensor_float32"), -1)));
FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
!not(!eq(!find(name, "_rtz"), -1)) : RTZ,
!not(!eq(!find(name, "_rtp"), -1)) : RTP,
Expand All @@ -1472,7 +1474,7 @@ class ConvertBuiltin<string name, InstructionSet set> {
def ConvertBuiltins : GenericTable {
let FilterClass = "ConvertBuiltin";
let Fields = ["Name", "Set", "IsDestinationSigned", "IsSaturated",
"IsRounded", "IsBfloat16", "RoundingMode"];
"IsRounded", "IsBfloat16", "IsTF32", "RoundingMode"];
string TypeOf_Set = "InstructionSet";
string TypeOf_RoundingMode = "FPRoundingMode";
}
Expand Down Expand Up @@ -1556,6 +1558,25 @@ foreach conv = ["FToBF16INTEL", "BF16ToFINTEL"] in {
def : ConvertBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std>;
}

// cl_intel_tensor_float32_conversions / SPV_INTEL_tensor_float32_conversion
// Multiclass used to define at the same time both a demangled builtin record
// and a corresponding convert builtin record.
multiclass DemangledTF32RoundBuiltin<string name1, string name2> {
// Create records for scalar and vector conversions.
foreach i = ["", "2", "3", "4", "8", "16"] in {
def : DemangledBuiltin<!strconcat("intel_round_", name1, i, name2, i), OpenCL_std, Convert, 1, 1>;
def : ConvertBuiltin<!strconcat("intel_round_", name1, i, name2, i), OpenCL_std>;
}
}

defm : DemangledTF32RoundBuiltin<"tensor_float32", "_as_float">;
defm : DemangledTF32RoundBuiltin<"as_tensor_float32", "_float">;

foreach conv = ["FToTF32INTEL"] in {
def : DemangledBuiltin<!strconcat("__spirv_Round", conv), OpenCL_std, Convert, 1, 1>;
def : ConvertBuiltin<!strconcat("__spirv_Round", conv), OpenCL_std>;
}

//===----------------------------------------------------------------------===//
// Class defining a vector data load/store builtin record used for lowering
// into OpExtInst instruction.
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
SPIRV::Extension::Extension::SPV_INTEL_2d_block_io},
{"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4},
{"SPV_KHR_float_controls2",
SPIRV::Extension::Extension::SPV_KHR_float_controls2}};
SPIRV::Extension::Extension::SPV_KHR_float_controls2},
{"SPV_INTEL_tensor_float32_conversion",
SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};

bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
StringRef ArgValue,
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,9 @@ def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938
def OpConvertFToBF16INTEL : UnOp<"OpConvertFToBF16INTEL", 6116>;
def OpConvertBF16ToFINTEL : UnOp<"OpConvertBF16ToFINTEL", 6117>;

// SPV_INTEL_tensor_float32_conversion
def OpRoundFToTF32INTEL : UnOp<"OpRoundFToTF32INTEL", 6426>;

// 3.42.12 Composite Instructions

def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx),
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1564,6 +1564,13 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
}
break;
case SPIRV::OpRoundFToTF32INTEL:
if (ST.canUseExtension(
SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
Reqs.addCapability(SPIRV::Capability::TensorFloat32RoundingINTEL);
}
break;
case SPIRV::OpVariableLengthArrayINTEL:
case SPIRV::OpSaveMemoryINTEL:
case SPIRV::OpRestoreMemoryINTEL:
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;
defm SPV_INTEL_2d_block_io : ExtensionOperand<122>;
defm SPV_INTEL_int4 : ExtensionOperand<123>;
defm SPV_KHR_float_controls2 : ExtensionOperand<124>;
defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125>;

//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
Expand Down Expand Up @@ -529,6 +530,7 @@ defm Subgroup2DBlockTransformINTEL : CapabilityOperand<6229, 0, 0, [SPV_INTEL_2d
defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;

//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
; CHECK-ERROR: result and argument must have the same number of components

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

define spir_func void @test(<8 x float> %in) {
%res = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in)
ret void
}

declare spir_func float @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>)
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
; CHECK-ERROR: result and argument must have the same number of components

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

define spir_func void @test(<8 x float> %in) {
%res = tail call spir_func <4 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in)
ret void
}

declare spir_func <4 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>)
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o - -filetype=obj | spirv-val %}

; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
; CHECK-ERROR: the builtin requires the following SPIR-V extension: SPV_INTEL_tensor_float32_conversion

; CHECK: OpCapability TensorFloat32RoundingINTEL
; CHECK: OpExtension "SPV_INTEL_tensor_float32_conversion"

; CHECK-DAG: %[[VoidTy:.*]] = OpTypeVoid
; CHECK-DAG: %[[FP32Ty:.*]] = OpTypeFloat 32
; CHECK-DAG: %[[VecFloat2:.*]] = OpTypeVector %[[FP32Ty]] 2
; CHECK-DAG: %[[VecFloat3:.*]] = OpTypeVector %[[FP32Ty]] 3
; CHECK-DAG: %[[VecFloat4:.*]] = OpTypeVector %[[FP32Ty]] 4
; CHECK-DAG: %[[VecFloat8:.*]] = OpTypeVector %[[FP32Ty]] 8
; CHECK-DAG: %[[VecFloat16:.*]] = OpTypeVector %[[FP32Ty]] 16
; CHECK-DAG: %[[FloatConstId:.*]] = OpConstant %[[FP32Ty]] 1.5

; CHECK: OpFunction %[[VoidTy]]
; CHECK: %[[FP32ValId:.*]] = OpFunctionParameter %[[FP32Ty]]
; CHECK: %[[FP32v8ValId:.*]] = OpFunctionParameter %[[VecFloat8]]
; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]] %[[FP32ValId]]
; CHECK: OpRoundFToTF32INTEL %[[VecFloat8]] %[[FP32v8ValId]]
; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]] %[[FloatConstId]]

; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]]
; CHECK: OpRoundFToTF32INTEL %[[VecFloat2]]
; CHECK: OpRoundFToTF32INTEL %[[VecFloat3]]
; CHECK: OpRoundFToTF32INTEL %[[VecFloat4]]
; CHECK: OpRoundFToTF32INTEL %[[VecFloat8]]
; CHECK: OpRoundFToTF32INTEL %[[VecFloat16]]

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

define spir_func void @test(float %a, <8 x float> %in) {
%res1 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float %a)
%res2 = tail call spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in)
%res3 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float 1.500000e+00)
ret void
}

declare spir_func float @_Z25__spirv_RoundFToTF32INTELf(float)
declare spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>)

define dso_local spir_kernel void @test_ocl(float %a) {
entry:
%res4 = call spir_func float @_Z35intel_round_as_tensor_float32_floatt(float 0.000000e+00)
%res5 = call spir_func <2 x float> @_Z37intel_round_as_tensor_float322_float2Dv2_t(<2 x float> zeroinitializer)
%res6 = call spir_func <3 x float> @_Z37intel_round_as_tensor_float323_float3Dv3_t(<3 x float> zeroinitializer)
%res7 = call spir_func <4 x float> @_Z37intel_round_as_tensor_float324_float4Dv4_t(<4 x float> zeroinitializer)
%res8 = call spir_func <8 x float> @_Z37intel_round_as_tensor_float328_float8Dv8_t(<8 x float> zeroinitializer)
%res9 = call spir_func <16 x float> @_Z39intel_round_as_tensor_float3216_float16Dv16_t(<16 x float> zeroinitializer)
ret void
}

declare spir_func float @_Z35intel_round_as_tensor_float32_floatt(float)
declare spir_func <2 x float> @_Z37intel_round_as_tensor_float322_float2Dv2_t(<2 x float>)
declare spir_func <3 x float> @_Z37intel_round_as_tensor_float323_float3Dv3_t(<3 x float>)
declare spir_func <4 x float> @_Z37intel_round_as_tensor_float324_float4Dv4_t(<4 x float>)
declare spir_func <8 x float> @_Z37intel_round_as_tensor_float328_float8Dv8_t(<8 x float>)
declare spir_func <16 x float> @_Z39intel_round_as_tensor_float3216_float16Dv16_t(<16 x float>)