Skip to content

Commit 65e9976

Browse files
wsmosesWilliam Mosesjumerckxgithub-actions[bot]
authored
Interp2 (EnzymeAD#365)
* WIP: kernels * more files * fix * wip * wqtmp * wip * inc * continuing * wip * more work * inf rec * fix * overload working * continuing * continuing * push * fix `call_with_reactant_generator` for Julia 1.11 (EnzymeAD#359) * conversion * continuing * Cleanup * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Delete test/cuda.jl * fixup * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix apply * indep of change * minor fix in name * Update utils.jl * Interp take 2 * continuing adentures * delcode * fix * tmp * make * fix * cleanup * continuing * more working * further simplify * fx * more improvements * minus show * less prints * even fewer * confusion * tmp * force clean * force oc * clean * Rewrite * fixup * fix * fix * fix * fixup * fix * wip * safe prints * fix * fix * stackoverflow * cleanup * dyindex * rt * continue * clean * fix * fix * fix * fix * fixup * fix * fix * capture oc * compile perf * v1.11 fix * other way 'round * formatting --------- Co-authored-by: William Moses <[email protected]> Co-authored-by: jumerckx <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: jumerckx <[email protected]>
1 parent 73899f5 commit 65e9976

22 files changed

+1578
-960
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,14 @@ Adapt = "4"
4141
ArrayInterface = "7.10"
4242
CEnum = "0.4, 0.5"
4343
Downloads = "1.6"
44-
Enzyme = "0.13.21"
44+
Enzyme = "0.13.22"
4545
EnzymeCore = "0.8.8"
4646
GPUArraysCore = "0.1.6, 0.2"
4747
LinearAlgebra = "1.10"
4848
NNlib = "0.9.26"
4949
OrderedCollections = "1"
5050
Preferences = "1.4"
51-
ReactantCore = "0.1.2"
51+
ReactantCore = "0.1.3"
5252
Reactant_jll = "0.0.26"
5353
Scratch = "1.2"
5454
Statistics = "1.10"

deps/ReactantExtra/API.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,16 @@ extern "C" MlirModule ConvertLLVMToMLIR(LLVMModuleRef lmod, MlirContext cctx) {
376376
return wrap(res);
377377
}
378378

379+
#include "llvm/IRReader/IRReader.h"
380+
extern "C" MlirModule ConvertLLVMStrToMLIR(const char* lmod, MlirContext cctx) {
381+
LLVMContext Context;
382+
SMDiagnostic Err;
383+
auto llvmModule = llvm::parseIR(llvm::MemoryBufferRef(lmod, "conversion"), Err, Context);
384+
mlir::MLIRContext &context = *unwrap(cctx);
385+
auto res = mlir::translateLLVMIRToModule(std::move(llvmModule), &context, /*emitExpensiveWarnings*/false, /*dropDICompositeElements*/false).release();
386+
return wrap(res);
387+
}
388+
379389

380390
/* Note that this */
381391
extern "C" xla::PjRtLoadedExecutable* ClientCompile(PjRtClient * client, MlirModule cmod) {

deps/ReactantExtra/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,8 @@ cc_library(
450450
"@llvm-project//mlir:SCFDialect",
451451
"@llvm-project//mlir:TransformDialect",
452452
"@llvm-project//mlir:Transforms",
453+
454+
"@llvm-project//llvm:IRReader",
453455
"@llvm-project//llvm:Support",
454456
"@llvm-project//llvm:AArch64AsmParser",
455457
"@llvm-project//llvm:AArch64CodeGen",

ext/ReactantNNlibExt.jl

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,10 @@ module ReactantNNlibExt
22

33
using NNlib
44
using GPUArraysCore: @allowscalar
5-
using Reactant:
6-
Reactant,
7-
Ops,
8-
TracedRArray,
9-
AnyTracedRArray,
10-
materialize_traced_array,
11-
MLIR,
12-
TracedRNumber,
13-
get_mlir_data,
14-
set_mlir_data!
5+
using Reactant: Reactant, Ops, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber
6+
7+
using Reactant.TracedUtils: materialize_traced_array, get_mlir_data, set_mlir_data!
8+
159
using ReactantCore: @trace
1610
using LinearAlgebra: LinearAlgebra, triu
1711

@@ -238,9 +232,9 @@ function NNlib.batched_mul!(
238232
if size(x, 3) != size(y, 3)
239233
B = max(size(x, 3), size(y, 3))
240234
if size(x, 3) == 1
241-
x = Reactant.broadcast_to_size(x, (size(x, 1), size(x, 2), B))
235+
x = Reactant.TracedUtils.broadcast_to_size(x, (size(x, 1), size(x, 2), B))
242236
elseif size(y, 3) == 1
243-
y = Reactant.broadcast_to_size(y, (size(y, 1), size(y, 2), B))
237+
y = Reactant.TracedUtils.broadcast_to_size(y, (size(y, 1), size(y, 2), B))
244238
end
245239
end
246240

@@ -250,9 +244,9 @@ function NNlib.batched_mul!(
250244
if size(x, 1) != size(y, 1)
251245
B = max(size(x, 1), size(y, 1))
252246
if size(x, 1) == 1
253-
x = Reactant.broadcast_to_size(x, (B, size(x, 2), size(x, 3)))
247+
x = Reactant.TracedUtils.broadcast_to_size(x, (B, size(x, 2), size(x, 3)))
254248
elseif size(y, 1) == 1
255-
y = Reactant.broadcast_to_size(y, (B, size(y, 2), size(y, 3)))
249+
y = Reactant.TracedUtils.broadcast_to_size(y, (B, size(y, 2), size(y, 3)))
256250
end
257251
end
258252

@@ -270,7 +264,7 @@ end
270264
function NNlib.pad_constant(
271265
x::AnyTracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value
272266
) where {T,N}
273-
value = Reactant.promote_to(TracedRNumber{T}, value)
267+
value = Reactant.TracedUtils.promote_to(TracedRNumber{T}, value)
274268
low = [i[1] for i in pad]
275269
high = [i[2] for i in pad]
276270
interior = [0 for i in pad]
@@ -329,7 +323,8 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
329323
start_sizes = ntuple(i -> size(src, i), dims)
330324
results = map(CartesianIndices(idxs)) do k
331325
res = @allowscalar src[colons..., Tuple(idxs[k])...]
332-
res isa TracedRNumber && (res = Reactant.broadcast_to_size(res, (1,)))
326+
res isa TracedRNumber &&
327+
(res = Reactant.TracedUtils.broadcast_to_size(res, (1,)))
333328
return reshape(res, start_sizes..., :)
334329
end
335330
res = reshape(cat(results...; dims=(dims + 1)), size(dst))

ext/ReactantStatisticsExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module ReactantStatisticsExt
22

3-
using Reactant: AnyTracedRArray, materialize_traced_array
3+
using Reactant: AnyTracedRArray
4+
using Reactant.TracedUtils: materialize_traced_array
45
using Statistics: Statistics
56

67
function Statistics.mean(A::AnyTracedRArray{T,N}; dims=:) where {T,N}

ext/ReactantYaoBlocksExt.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
module ReactantYaoBlocksExt
22

33
using Reactant
4+
using Reactant.TracedUtils: broadcast_to_size
45
using YaoBlocks
56

67
function YaoBlocks.mat(
78
::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:XGate}
89
) where {D,T,S}
9-
M = Reactant.broadcast_to_size(zero(T), (2, 2))
10+
M = broadcast_to_size(zero(T), (2, 2))
1011
c = cos(R.theta / 2)
1112
s = -im * sin(R.theta / 2)
1213
M[1, 1] = c
@@ -19,7 +20,7 @@ end
1920
function YaoBlocks.mat(
2021
::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:YGate}
2122
) where {D,T,S}
22-
M = Reactant.broadcast_to_size(zero(T), (2, 2))
23+
M = broadcast_to_size(zero(T), (2, 2))
2324
c = cos(R.theta / 2)
2425
s = sin(R.theta / 2)
2526
M[1, 1] = c
@@ -32,7 +33,7 @@ end
3233
function YaoBlocks.mat(
3334
::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:ZGate}
3435
) where {D,T,S}
35-
M = Reactant.broadcast_to_size(zero(T), (2, 2))
36+
M = broadcast_to_size(zero(T), (2, 2))
3637
x = exp(im * R.theta / 2)
3738
M[1, 1] = conj(x)
3839
M[2, 2] = x

lib/ReactantCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ReactantCore"
22
uuid = "a3311ec8-5e00-46d5-b541-4f83e724a433"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>"]
4-
version = "0.1.2"
4+
version = "0.1.3"
55

66
[deps]
77
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"

lib/ReactantCore/src/ReactantCore.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,17 @@ function trace_for(mod, expr)
153153

154154
all_syms = Expr(:tuple, counter, external_syms...)
155155
args_init = Expr(
156-
:tuple, :(Reactant.promote_to(Reactant.TracedRNumber{Int}, 0)), external_syms...
156+
:tuple,
157+
:(Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, 0)),
158+
external_syms...,
157159
)
158160

159161
reactant_code_block = quote
160162
let args = $(args_init)
161163
cond_fn =
162164
$(all_syms) -> begin
163165
local num_iters = div($limit - $start, $step, RoundDown)
164-
local num_iters = Reactant.promote_to(
166+
local num_iters = Reactant.TracedUtils.promote_to(
165167
Reactant.TracedRNumber{Int64}, num_iters
166168
)
167169
$counter < num_iters + 1

src/Compiler.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
292292
func2, traced_result, result, seen_args, ret, linear_args, in_tys,
293293
linear_results = MLIR.IR.mmodule!(mod) do
294294
MLIR.IR.block!(MLIR.IR.body(mod)) do
295-
return Reactant.make_mlir_fn(f, args, (), "main", true)
295+
return Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
296296
end
297297
end
298298

@@ -779,6 +779,13 @@ function compile(f, args; client=nothing, optimize=true, sync=false)
779779
return register_thunk(fname, body)
780780
end
781781

782+
# Compiling within a compile should return simply the original function
783+
Reactant.@reactant_override function Reactant.Compiler.compile(
784+
f, args; client=nothing, optimize=true, sync=false
785+
)
786+
return f
787+
end
788+
782789
# inspired by RuntimeGeneratedFunction.jl
783790
const __thunk_body_cache = Dict{Symbol,Expr}()
784791

src/ConcreteRArray.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ end
9999
function Base.convert(
100100
::Type{T}, X::WrappedConcreteRArray{ElType,N}
101101
) where {T<:Array,ElType,N}
102-
fn = compile(materialize_traced_array, (X,))
102+
fn = compile(TracedUtils.materialize_traced_array, (X,))
103103
return convert(Array, fn(X))
104104
end
105105
Base.Array(x::AnyConcreteRArray) = convert(Array, x)
@@ -345,3 +345,11 @@ end
345345

346346
buffer_on_cpu(::Any) = true
347347
buffer_on_cpu(x::ConcreteRArray) = XLA.BufferOnCPU(x.data.buffer)
348+
349+
function Ops.constant(x::ConcreteRArray; kwargs...)
350+
return Ops.constant(Base.convert(Array, x); kwargs...)
351+
end
352+
353+
function Ops.constant(x::ConcreteRNumber{T}; kwargs...) where {T}
354+
return Ops.constant(Base.convert(T, x); kwargs...)
355+
end

0 commit comments

Comments
 (0)