Skip to content

Commit f600f07

Browse files
authored
Also dump MLIR module of failed compilation inside IFRT and PJRT (#1021)
1 parent c78c5a0 commit f600f07

File tree

3 files changed

+63
-41
lines changed

3 files changed

+63
-41
lines changed

src/mlir/IR/Pass.jl

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -66,40 +66,58 @@ end
6666

6767
const DUMP_MLIR_DIR = Ref{Union{Nothing,String}}(nothing)
6868

69-
"""
70-
run!(passManager, module)
71-
72-
Run the provided `passManager` on the given `module`.
73-
"""
74-
function run!(pm::PassManager, mod::Module)
75-
status = LogicalResult(@static if isdefined(API, :mlirPassManagerRunOnOp)
76-
API.mlirPassManagerRunOnOp(pm, Operation(mod))
77-
else
78-
API.mlirPassManagerRun(pm, mod)
79-
end)
80-
if isfailure(status)
69+
# Utilities for dumping to a file the module of a failed compilation, useful for
70+
# debugging purposes.
71+
function compilation_failed_dump_mlir(mod::Module, pm::Union{Nothing,PassManager}=nothing)
72+
try
8173
# If `DUMP_MLIR_DIR` is `nothing`, create a persistent new temp
8274
# directory, otherwise use the provided path.
8375
dir = if isnothing(DUMP_MLIR_DIR[])
76+
mkpath(tempdir())
8477
mktempdir(; prefix="reactant_", cleanup=false)
8578
else
8679
DUMP_MLIR_DIR[]
8780
end
88-
try
89-
# Make sure the directory exists
90-
mkpath(dir)
91-
path = tempname(dir; cleanup=false) * ".mlir"
92-
open(path, "w") do io
81+
# Make sure the directory exists
82+
mkpath(dir)
83+
path = tempname(dir; cleanup=false) * ".mlir"
84+
open(path, "w") do io
85+
if !isnothing(pm)
9386
println(io, "// Pass pipeline:")
9487
print(io, "// ")
9588
print_pass_pipeline(io, OpPassManager(pm))
9689
println(io)
97-
show(IOContext(io, :debug => true), mod)
9890
end
99-
@error "Dumped module to " * path
100-
catch err
101-
@error "Couldn't save MLIR module" exception = err
91+
show(IOContext(io, :debug => true), mod)
10292
end
93+
@error "Compilation failed, MLIR module written to $(path)"
94+
catch err
95+
@error "Couldn't save MLIR module" exception = err
96+
end
97+
end
98+
99+
function try_compile_dump_mlir(f, mod::Module, pm=nothing)
100+
try
101+
f()
102+
catch
103+
compilation_failed_dump_mlir(mod, pm)
104+
rethrow()
105+
end
106+
end
107+
108+
"""
109+
run!(passManager, module)
110+
111+
Run the provided `passManager` on the given `module`.
112+
"""
113+
function run!(pm::PassManager, mod::Module)
114+
status = LogicalResult(@static if isdefined(API, :mlirPassManagerRunOnOp)
115+
API.mlirPassManagerRunOnOp(pm, Operation(mod))
116+
else
117+
API.mlirPassManagerRun(pm, mod)
118+
end)
119+
if isfailure(status)
120+
compilation_failed_dump_mlir(mod, pm)
103121
throw("failed to run pass manager on module")
104122
end
105123
return mod

src/xla/IFRT/LoadedExecutable.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,18 @@ function XLA.compile(
8585
)
8686
device_id = is_sharded ? Int64(-1) : Int64(XLA.device_ordinal(device))
8787
GC.@preserve client mod begin
88-
exec = @ccall MLIR.API.mlir_c.ifrt_compile(
89-
client.client::Ptr{Cvoid},
90-
mod.module_::MLIR.API.MlirModule,
91-
device_id::Clong,
92-
is_sharded::Bool,
93-
global_device_ids::Ptr{Clong},
94-
length(global_device_ids)::Clong,
95-
XLA.CUDA_DATA_DIR[]::Cstring,
96-
use_shardy_partitioner::Bool,
97-
)::Ptr{Cvoid}
88+
exec = MLIR.IR.try_compile_dump_mlir(mod) do
89+
@ccall MLIR.API.mlir_c.ifrt_compile(
90+
client.client::Ptr{Cvoid},
91+
mod.module_::MLIR.API.MlirModule,
92+
device_id::Clong,
93+
is_sharded::Bool,
94+
global_device_ids::Ptr{Clong},
95+
length(global_device_ids)::Clong,
96+
XLA.CUDA_DATA_DIR[]::Cstring,
97+
use_shardy_partitioner::Bool,
98+
)::Ptr{Cvoid}
99+
end
98100
end
99101
return LoadedExecutable(
100102
exec, num_outputs, num_parameters, is_sharded, num_replicas, num_partitions

src/xla/PJRT/LoadedExecutable.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,18 @@ function XLA.compile(
7979
)
8080
device_id = is_sharded ? Int64(-1) : Int64(XLA.device_ordinal(device))
8181
GC.@preserve client mod begin
82-
exec = @ccall MLIR.API.mlir_c.ClientCompile(
83-
client.client::Ptr{Cvoid},
84-
mod.module_::MLIR.API.MlirModule,
85-
device_id::Clong,
86-
is_sharded::Bool,
87-
global_device_ids::Ptr{Clong},
88-
length(global_device_ids)::Clong,
89-
XLA.CUDA_DATA_DIR[]::Cstring,
90-
use_shardy_partitioner::Bool,
91-
)::Ptr{Cvoid}
82+
exec = MLIR.IR.try_compile_dump_mlir(mod) do
83+
@ccall MLIR.API.mlir_c.ClientCompile(
84+
client.client::Ptr{Cvoid},
85+
mod.module_::MLIR.API.MlirModule,
86+
device_id::Clong,
87+
is_sharded::Bool,
88+
global_device_ids::Ptr{Clong},
89+
length(global_device_ids)::Clong,
90+
XLA.CUDA_DATA_DIR[]::Cstring,
91+
use_shardy_partitioner::Bool,
92+
)::Ptr{Cvoid}
93+
end
9294
end
9395
return LoadedExecutable(
9496
exec, num_outputs, num_parameters, is_sharded, num_replicas, num_partitions

0 commit comments

Comments
 (0)