Skip to content

Commit 364901f

Browse files
giordanoavik-pal
andauthored
Update Reactant_jll and bump version number (#984)
* Update Reactant_jll and bump version number * fix: apply transpose in sharding_constraint --------- Co-authored-by: Avik Pal <[email protected]>
1 parent 495a75c commit 364901f

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.46"
4+
version = "0.2.47"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -86,7 +86,7 @@ PythonCall = "0.9"
8686
Random = "1.10"
8787
Random123 = "1.7"
8888
ReactantCore = "0.1.6"
89-
Reactant_jll = "0.0.94"
89+
Reactant_jll = "0.0.95"
9090
Scratch = "1.2"
9191
Sockets = "1.10"
9292
SpecialFunctions = "2.4"

src/Ops.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2388,7 +2388,7 @@ Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation wi
23882388
haskey(cache, sharding.mesh) || Ops.mesh(sharding.mesh; location)
23892389
(; sym_name, mesh_attr) = cache[sharding.mesh]
23902390
tensor_sharding_attr = Reactant.Sharding.get_shardy_tensor_sharding_attribute(
2391-
sharding, MLIR.IR.context(), sym_name, mesh_attr; do_transpose=true
2391+
sharding, MLIR.IR.context(), sym_name, mesh_attr; do_transpose=false
23922392
)
23932393
resharded_value = MLIR.IR.result(
23942394
MLIR.Dialects.sdy.sharding_constraint(

test/sharding.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,12 +209,13 @@ end
209209
@test contains(repr(hlo), "sharding_constraint")
210210
hlo = @code_hlo shardy_passes = :to_mhlo_shardings fn_with_constraint(x_ra)
211211
@test !contains(repr(hlo), "sharding_constraint")
212-
@test length(collect(eachmatch(r"mhlo.sharding", repr(hlo)))) == 3
212+
@test length(collect(eachmatch(r"mhlo.sharding", repr(hlo)))) == 5
213213

214214
z = Reactant.to_rarray(x; sharding=constraint)
215215
res = @jit fn_with_constraint(x_ra)
216216

217217
@test x .+ x Array(res)
218+
218219
@test string(z.sharding.sharding.hlo_sharding) ==
219220
string(res.sharding.sharding.hlo_sharding)
220221
@test string(res.sharding.sharding.hlo_sharding) !=
@@ -229,7 +230,7 @@ end
229230
x_ra_no_sharding
230231
)
231232
@test !contains(repr(hlo), "sharding_constraint")
232-
@test length(collect(eachmatch(r"mhlo.sharding", repr(hlo)))) == 3
233+
@test length(collect(eachmatch(r"mhlo.sharding", repr(hlo)))) == 5
233234

234235
res = @jit fn_with_constraint(x_ra_no_sharding)
235236
@test x .+ x Array(res)

0 commit comments

Comments
 (0)