Skip to content

Commit 2d20337

Browse files
Regenerate MLIR Bindings (#1084)
Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
1 parent 0c1b0e3 commit 2d20337

File tree

4 files changed

+207
-22
lines changed

4 files changed

+207
-22
lines changed

src/mlir/Dialects/Llvm.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,45 @@ function mlir_constant(; res::IR.Type, value, location=Location())
726726
)
727727
end
728728

729+
"""
730+
`dso_local_equivalent`
731+
732+
Creates an SSA value containing a pointer to a global value (function or
733+
alias to function). It represents a function which is functionally
734+
equivalent to a given function, but is always defined in the current
735+
linkage unit. The target function may not have `extern_weak` linkage.
736+
737+
Examples:
738+
739+
```mlir
740+
llvm.mlir.global external constant @const() : i64 {
741+
%0 = llvm.mlir.addressof @const : !llvm.ptr
742+
%1 = llvm.ptrtoint %0 : !llvm.ptr to i64
743+
%2 = llvm.dso_local_equivalent @func : !llvm.ptr
744+
%4 = llvm.ptrtoint %2 : !llvm.ptr to i64
745+
llvm.return %4 : i64
746+
}
747+
```
748+
"""
749+
function dso_local_equivalent(; res::IR.Type, function_name, location=Location())
750+
op_ty_results = IR.Type[res,]
751+
operands = Value[]
752+
owned_regions = Region[]
753+
successors = Block[]
754+
attributes = NamedAttribute[namedattribute("function_name", function_name),]
755+
756+
return create_operation(
757+
"llvm.dso_local_equivalent",
758+
location;
759+
operands,
760+
owned_regions,
761+
successors,
762+
attributes,
763+
results=op_ty_results,
764+
result_inference=false,
765+
)
766+
end
767+
729768
function extractelement(
730769
vector::Value, position::Value; res=nothing::Union{Nothing,IR.Type}, location=Location()
731770
)

src/mlir/Dialects/Nvvm.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,37 @@ function breakpoint(; location=Location())
304304
)
305305
end
306306

307+
"""
308+
`st_bulk`
309+
310+
Initializes a region of shared memory at the address given by `addr`.
311+
The `size` operand specifies the number of bytes to initialize and must be
312+
a multiple of 8.
313+
The `initVal` operand specifies the value to initialize the memory to. The
314+
only supported value is 0.
315+
316+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st-bulk)
317+
"""
318+
function st_bulk(addr::Value, size::Value; initVal=nothing, location=Location())
319+
op_ty_results = IR.Type[]
320+
operands = Value[addr, size]
321+
owned_regions = Region[]
322+
successors = Block[]
323+
attributes = NamedAttribute[]
324+
!isnothing(initVal) && push!(attributes, namedattribute("initVal", initVal))
325+
326+
return create_operation(
327+
"nvvm.st.bulk",
328+
location;
329+
operands,
330+
owned_regions,
331+
successors,
332+
attributes,
333+
results=op_ty_results,
334+
result_inference=false,
335+
)
336+
end
337+
307338
function read_ptx_sreg_clock64(; res::IR.Type, location=Location())
308339
op_ty_results = IR.Type[res,]
309340
operands = Value[]

src/mlir/Dialects/Triton.jl

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,20 @@ symbol reference attribute named \"callee\".
2828
```
2929
"""
3030
function call(
31-
operands::Vector{Value}; result_0::Vector{IR.Type}, callee, location=Location()
31+
operands::Vector{Value};
32+
result_0::Vector{IR.Type},
33+
callee,
34+
arg_attrs=nothing,
35+
res_attrs=nothing,
36+
location=Location(),
3237
)
3338
op_ty_results = IR.Type[result_0...,]
3439
operands = Value[operands...,]
3540
owned_regions = Region[]
3641
successors = Block[]
3742
attributes = NamedAttribute[namedattribute("callee", callee),]
43+
!isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs))
44+
!isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs))
3845

3946
return create_operation(
4047
"tt.call",
@@ -459,39 +466,39 @@ end
459466
"""
460467
`dot_scaled`
461468
462-
\$d = matrix_multiply(scale(\$lhs, \$lhs_scale), scale(rlhs, \$rhs_scale)) + \$c.
469+
\$d = matrix_multiply(scale(\$a, \$a_scale), scale(\$b, \$b_scale)) + \$c.
463470
Where scale(x, s) is a function that applies the scale per block following microscaling spec.
464471
"""
465472
function dot_scaled(
466-
lhs::Value,
467-
rhs::Value,
473+
a::Value,
474+
b::Value,
468475
c::Value,
469-
lhs_scale=nothing::Union{Nothing,Value};
470-
rhs_scale=nothing::Union{Nothing,Value},
476+
a_scale=nothing::Union{Nothing,Value};
477+
b_scale=nothing::Union{Nothing,Value},
471478
d::IR.Type,
472-
lhs_type,
473-
rhs_type,
479+
a_elem_type,
480+
b_elem_type,
474481
fastMath,
475482
location=Location(),
476483
)
477484
op_ty_results = IR.Type[d,]
478-
operands = Value[lhs, rhs, c]
485+
operands = Value[a, b, c]
479486
owned_regions = Region[]
480487
successors = Block[]
481488
attributes = NamedAttribute[
482-
namedattribute("lhs_type", lhs_type),
483-
namedattribute("rhs_type", rhs_type),
489+
namedattribute("a_elem_type", a_elem_type),
490+
namedattribute("b_elem_type", b_elem_type),
484491
namedattribute("fastMath", fastMath),
485492
]
486-
!isnothing(lhs_scale) && push!(operands, lhs_scale)
487-
!isnothing(rhs_scale) && push!(operands, rhs_scale)
493+
!isnothing(a_scale) && push!(operands, a_scale)
494+
!isnothing(b_scale) && push!(operands, b_scale)
488495
push!(attributes, operandsegmentsizes([
489496
1,
490497
1,
491498
1,
492-
if (lhs_scale == nothing)
499+
if (a_scale == nothing)
493500
0
494-
elseif 1(rhs_scale == nothing)
501+
elseif 1(b_scale == nothing)
495502
0
496503
else
497504
1
@@ -973,15 +980,12 @@ shape 4x8x2xf32.
973980
Because Triton tensors always have a power-of-two number of elements,
974981
the two input tensors must have the same shape.
975982
"""
976-
function join(
977-
lhs::Value, rhs::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()
978-
)
979-
op_ty_results = IR.Type[]
983+
function join(lhs::Value, rhs::Value; result::IR.Type, location=Location())
984+
op_ty_results = IR.Type[result,]
980985
operands = Value[lhs, rhs]
981986
owned_regions = Region[]
982987
successors = Block[]
983988
attributes = NamedAttribute[]
984-
!isnothing(result) && push!(op_ty_results, result)
985989

986990
return create_operation(
987991
"tt.join",
@@ -990,8 +994,8 @@ function join(
990994
owned_regions,
991995
successors,
992996
attributes,
993-
results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
994-
result_inference=(length(op_ty_results) == 0 ? true : false),
997+
results=op_ty_results,
998+
result_inference=false,
995999
)
9961000
end
9971001

src/mlir/libMLIR_h.jl

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7775,6 +7775,101 @@ function mlirUniformQuantizedPerAxisTypeIsFixedPoint(type)
77757775
@ccall mlir_c.mlirUniformQuantizedPerAxisTypeIsFixedPoint(type::MlirType)::Bool
77767776
end
77777777

7778+
"""
7779+
mlirTypeIsAUniformQuantizedSubChannelType(type)
7780+
7781+
Returns `true` if the given type is a UniformQuantizedSubChannel.
7782+
"""
7783+
function mlirTypeIsAUniformQuantizedSubChannelType(type)
7784+
@ccall mlir_c.mlirTypeIsAUniformQuantizedSubChannelType(type::MlirType)::Bool
7785+
end
7786+
7787+
"""
7788+
mlirUniformQuantizedSubChannelTypeGet(flags, storageType, expressedType, scalesAttr, zeroPointsAttr, blockSizeInfoLength, quantizedDimensions, blockSizes, storageTypeMin, storageTypeMax)
7789+
7790+
Creates a UniformQuantizedSubChannelType with the given parameters.
7791+
7792+
The type is owned by the context. `scalesAttr` and `zeroPointsAttr` must be DenseElementsAttrs. `quantizedDimensions` and `blockSizes` point to `blockSizeInfoLength` number of elements, describing respectively the quantization axis and corresponding block size.
7793+
"""
7794+
function mlirUniformQuantizedSubChannelTypeGet(
7795+
flags,
7796+
storageType,
7797+
expressedType,
7798+
scalesAttr,
7799+
zeroPointsAttr,
7800+
blockSizeInfoLength,
7801+
quantizedDimensions,
7802+
blockSizes,
7803+
storageTypeMin,
7804+
storageTypeMax,
7805+
)
7806+
@ccall mlir_c.mlirUniformQuantizedSubChannelTypeGet(
7807+
flags::Cuint,
7808+
storageType::MlirType,
7809+
expressedType::MlirType,
7810+
scalesAttr::MlirAttribute,
7811+
zeroPointsAttr::MlirAttribute,
7812+
blockSizeInfoLength::intptr_t,
7813+
quantizedDimensions::Ptr{Int32},
7814+
blockSizes::Ptr{Int64},
7815+
storageTypeMin::Int64,
7816+
storageTypeMax::Int64,
7817+
)::MlirType
7818+
end
7819+
7820+
"""
7821+
mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type)
7822+
7823+
Returns the number of block sizes provided in type.
7824+
"""
7825+
function mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type)
7826+
@ccall mlir_c.mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(
7827+
type::MlirType
7828+
)::intptr_t
7829+
end
7830+
7831+
"""
7832+
mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, pos)
7833+
7834+
Returns the quantized dimension at the given position.
7835+
"""
7836+
function mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, pos)
7837+
@ccall mlir_c.mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(
7838+
type::MlirType, pos::intptr_t
7839+
)::Int32
7840+
end
7841+
7842+
"""
7843+
mlirUniformQuantizedSubChannelTypeGetBlockSize(type, pos)
7844+
7845+
Returns the block size at the given position.
7846+
"""
7847+
function mlirUniformQuantizedSubChannelTypeGetBlockSize(type, pos)
7848+
@ccall mlir_c.mlirUniformQuantizedSubChannelTypeGetBlockSize(
7849+
type::MlirType, pos::intptr_t
7850+
)::Int64
7851+
end
7852+
7853+
"""
7854+
mlirUniformQuantizedSubChannelTypeGetScales(type)
7855+
7856+
Returns the scales of the quantized type.
7857+
"""
7858+
function mlirUniformQuantizedSubChannelTypeGetScales(type)
7859+
@ccall mlir_c.mlirUniformQuantizedSubChannelTypeGetScales(type::MlirType)::MlirAttribute
7860+
end
7861+
7862+
"""
7863+
mlirUniformQuantizedSubChannelTypeGetZeroPoints(type)
7864+
7865+
Returns the zero-points of the quantized type.
7866+
"""
7867+
function mlirUniformQuantizedSubChannelTypeGetZeroPoints(type)
7868+
@ccall mlir_c.mlirUniformQuantizedSubChannelTypeGetZeroPoints(
7869+
type::MlirType
7870+
)::MlirAttribute
7871+
end
7872+
77787873
"""
77797874
mlirTypeIsACalibratedQuantizedType(type)
77807875
@@ -10369,6 +10464,8 @@ function sdyOpShardingRuleAttrGet(
1036910464
needReplicationFactors,
1037010465
nPermutationFactors,
1037110466
permutationFactors,
10467+
nBlockedPropagationFactors,
10468+
blockedPropagationFactors,
1037210469
isCustomRule,
1037310470
)
1037410471
@ccall mlir_c.sdyOpShardingRuleAttrGet(
@@ -10385,6 +10482,8 @@ function sdyOpShardingRuleAttrGet(
1038510482
needReplicationFactors::Ptr{Int64},
1038610483
nPermutationFactors::intptr_t,
1038710484
permutationFactors::Ptr{Int64},
10485+
nBlockedPropagationFactors::Int64,
10486+
blockedPropagationFactors::Ptr{Int64},
1038810487
isCustomRule::Bool,
1038910488
)::MlirAttribute
1039010489
end
@@ -10459,6 +10558,18 @@ function sdyOpShardingRuleAttrGetPermutationFactorsElem(attr, pos)
1045910558
)::Int64
1046010559
end
1046110560

10561+
function sdyOpShardingRuleAttrGetBlockedPropagationFactorsSize(attr)
10562+
@ccall mlir_c.sdyOpShardingRuleAttrGetBlockedPropagationFactorsSize(
10563+
attr::MlirAttribute
10564+
)::intptr_t
10565+
end
10566+
10567+
function sdyOpShardingRuleAttrGetBlockedPropagationFactorsElem(attr, pos)
10568+
@ccall mlir_c.sdyOpShardingRuleAttrGetBlockedPropagationFactorsElem(
10569+
attr::MlirAttribute, pos::intptr_t
10570+
)::Int64
10571+
end
10572+
1046210573
function sdyAttributeIsAManualAxesAttr(attr)
1046310574
@ccall mlir_c.sdyAttributeIsAManualAxesAttr(attr::MlirAttribute)::Bool
1046410575
end

0 commit comments

Comments
 (0)