Skip to content

Commit 63d407a

Browse files
PTX fma and other flags (#585)
* PTX fma and other flags * we can keep cuda debug info now, that's cool * version bump * Update XLA.jl * Update src/XLA.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent cca721d commit 63d407a

File tree

6 files changed

+22
-6
lines changed

6 files changed

+22
-6
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>"]
4-
version = "0.2.20"
4+
version = "0.2.21"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -70,7 +70,7 @@ PythonCall = "0.9"
7070
Random = "1.10"
7171
Random123 = "1.7"
7272
ReactantCore = "0.1.4"
73-
Reactant_jll = "0.0.45"
73+
Reactant_jll = "0.0.46"
7474
Scratch = "1.2"
7575
Sockets = "1.10"
7676
SpecialFunctions = "2.4"

deps/ReactantExtra/API.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,12 @@ std::vector<int64_t> col_major(int64_t dim) {
446446
return minor_to_major;
447447
}
448448

449+
extern "C" void ReactantLLVMParseCommandLineOptions(int argc, const char *const *argv,
450+
const char *Overview) {
451+
llvm::cl::ParseCommandLineOptions(argc, argv, StringRef(Overview),
452+
&llvm::nulls());
453+
}
454+
449455
std::vector<int64_t> row_major(int64_t dim) {
450456
std::vector<int64_t> minor_to_major;
451457
for (int i = 0; i < dim; i++) {

deps/ReactantExtra/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,8 @@ cc_library(
450450
"-Wl,-exported_symbol,_ProfilerActivityStart",
451451
"-Wl,-exported_symbol,_ProfilerActivityEnd",
452452
"-Wl,-exported_symbol,_ReactantFuncSetArgAttr",
453-
"-Wl,-exported_symbol,_ReactantCudaDriverGetVersion"
453+
"-Wl,-exported_symbol,_ReactantCudaDriverGetVersion",
454+
"-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions"
454455
]}),
455456
deps = [
456457
"@enzyme//:EnzymeMLIR",

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ http_archive(
136136
)
137137

138138
# load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
139-
XLA_COMMIT = "281c11225c4a0bb7b710a290610a06d71194febd"
139+
XLA_COMMIT = "e0c92850a41cf5208744d8a919b969fa3506863c"
140140
XLA_SHA256 = ""
141141

142142
http_archive(

ext/ReactantCUDAExt.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -400,11 +400,10 @@ function compile(job)
400400
if !isempty(errors)
401401
throw(GPUCompiler.InvalidIRError(job, errors))
402402
end
403-
LLVM.strip_debuginfo!(mod)
403+
# LLVM.strip_debuginfo!(mod)
404404
modstr = string(mod)
405405
# This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version
406406
# it is probably safer to reparse a string using the right llvm module api, so we will do that.
407-
408407
mmod = MLIR.IR.Module(
409408
@ccall MLIR.API.mlir_c.ConvertLLVMStrToMLIR(
410409
modstr::Cstring, MLIR.IR.context()::MLIR.API.MlirContext

src/XLA.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@ module XLA
22

33
import ...MLIR
44

5+
function LLVMclopts(opts...)
6+
args = ["", opts...]
7+
@ccall MLIR.API.mlir_c.ReactantLLVMParseCommandLineOptions(
8+
length(args)::Cint, args::Ptr{Cstring}, C_NULL::Ptr{Cvoid}
9+
)::Cvoid
10+
end
11+
512
mutable struct Client
613
client::Ptr{Cvoid}
714

@@ -50,6 +57,7 @@ function CPUClient(asynchronous=false, node_id=0, num_nodes=1; checkcount=true)
5057
end
5158
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeCPUClient")
5259
client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes)
60+
LLVMclopts("-nvptx-fma-level=1")
5361
#client = @ccall MLIR.API.mlir_c.MakeCPUClient(asynchronous::UInt8, node_id::Cint, num_nodes::Cint)::Ptr{Cvoid}
5462
return Client(client)
5563
end
@@ -73,6 +81,7 @@ function GPUClient(node_id=0, num_nodes=1, platform="gpu")
7381
if client == C_NULL
7482
throw(AssertionError(unsafe_string(refstr[])))
7583
end
84+
LLVMclopts("-nvptx-fma-level=1")
7685
return Client(client)
7786
end
7887

@@ -83,6 +92,7 @@ function TPUClient(tpu_path::String)
8392
if client == C_NULL
8493
throw(AssertionError(unsafe_string(refstr[])))
8594
end
95+
LLVMclopts("-nvptx-fma-level=1")
8696
return Client(client)
8797
end
8898

0 commit comments

Comments
 (0)