Skip to content

Commit 43a4d83

Browse files
apaszkejax authors
authored andcommitted
Update the TPU dialect binding extension to follow MLIR guidelines
The way MLIR dialects are allowed to be extended in Python has recently changed (in llvm/llvm-project#68853), so we have to update our bindings. PiperOrigin-RevId: 575796552
1 parent 20e5838 commit 43a4d83

File tree

4 files changed

+32
-41
lines changed

4 files changed

+32
-41
lines changed

jaxlib/mosaic/python/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ py_library(
4848
name = "tpu_dialect",
4949
srcs = [
5050
"_tpu_gen.py",
51-
"_tpu_ops_ext.py",
5251
"tpu.py",
5352
],
5453
visibility = ["//visibility:public"],

jaxlib/mosaic/python/_tpu_ops_ext.py

Lines changed: 0 additions & 39 deletions
This file was deleted.

jaxlib/mosaic/python/tpu.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,38 @@
1717
# flake8: noqa: F401
1818
# flake8: noqa: F403
1919

20+
2021
# pylint: disable=g-bad-import-order
2122
from ._tpu_gen import * # pylint: disable=wildcard-import
23+
from ._tpu_gen import _Dialect
2224
from jaxlib.mlir._mlir_libs._tpu_ext import * # pylint: disable=wildcard-import
25+
try:
26+
from jaxlib.mlir.dialects._ods_common import _cext
27+
except ImportError:
28+
from mlir.dialects._ods_common import _cext
29+
30+
31+
@_cext.register_operation(_Dialect, replace=True)
32+
class TraceOp(TraceOp):
33+
"""An extension to the automatically generated TraceOp bindings."""
34+
35+
def __init__(self, results, message, level, *, loc=None, ip=None):
36+
super().__init__(results, message, level, loc=loc, ip=ip)
37+
self.regions[0].blocks.append(*[]) # Append the block.
38+
39+
@property
40+
def body(self):
41+
return self.regions[0].blocks[0]
42+
43+
44+
@_cext.register_operation(_Dialect, replace=True)
45+
class RegionOp(RegionOp):
46+
"""An extension to the automatically generated RegionOp bindings."""
47+
48+
def __init__(self, *, loc=None, ip=None):
49+
super().__init__([], loc=loc, ip=ip)
50+
self.regions[0].blocks.append() # Append the block.
51+
52+
@property
53+
def body(self):
54+
return self.regions[0].blocks[0]

jaxlib/tools/build_wheel.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,6 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, include_gpu_plugin_extensi
254254
"__main__/jaxlib/mosaic/python/apply_vector_layout.py",
255255
"__main__/jaxlib/mosaic/python/infer_memref_layout.py",
256256
"__main__/jaxlib/mosaic/python/tpu.py",
257-
"__main__/jaxlib/mosaic/python/_tpu_ops_ext.py",
258257
],
259258
)
260259
# TODO (sharadmv,skyewm): can we avoid patching this file?

0 commit comments

Comments
 (0)