From 9799ed0922a3b2eb728d9bc9eae682deb8e7364a Mon Sep 17 00:00:00 2001
From: Sungjoon Shon
Date: Mon, 11 Aug 2025 21:06:44 +0000
Subject: [PATCH] feat: add support for custom compile options in
 torch_xla.compile and PJRT backend
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This change introduces the ability to pass custom compile options from
Python down to the PJRT backend, allowing users to fine-tune XLA
compilation behavior without modifying core code.

Key changes:

* Python API
  * Added custom_compile_options parameter to torch_xla.compile for
    passing compile-time options as a dict (supports bool, float, int,
    and str values).
  * Added torch_xla.set_custom_compile_options() utility for setting
    compile options globally.
  * Added internal binding _XLAC._set_custom_compile_options().
* C++ Runtime
  * Added SetCustomCompileOptions() virtual method to ComputationClient
    and implemented it in PjRtComputationClient.
  * PjRtComputationClient now stores custom_compile_options_ and
    injects them into xla::CompileOptions.env_option_overrides during
    compilation.
  * Options are stringified before being passed to XLA for
    compatibility.

Motivation:
This enables advanced users to pass through backend-specific tuning
flags (e.g., enabling experimental optimizations, toggling partitioning
strategies) without hardcoding them, improving flexibility for research
and debugging workflows.
---
 torch_xla/csrc/init_python_bindings.cpp            | 13 +++++++++++++
 torch_xla/csrc/runtime/computation_client.h        |  8 ++++++++
 .../csrc/runtime/ifrt_computation_client.h         |  5 +++++
 .../csrc/runtime/pjrt_computation_client.cpp       | 12 ++++++++++++
 .../csrc/runtime/pjrt_computation_client.h         |  5 +++++
 torch_xla/torch_xla.py                             | 19 +++++++++++++++++++
 6 files changed, 62 insertions(+)

diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index a52ecc8124e7..66657e4a2549 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -3306,6 +3306,19 @@ void InitXlaModuleBindings(py::module m) {
             XLA_ERROR() << "Could not get the buffer pointer for XLATensor "
                            "without a data handle or an IR.";
           })
+      .def("_set_custom_compile_options",
+           [](const py::dict& compile_options) {
+             std::unordered_map<std::string, std::string> options;
+             for (const auto& item : compile_options) {
+               // Keys must be strings; values are stringified.
+               const std::string key = py::str(item.first);
+               options[key] = py::str(item.second);
+             }
+             XLA_ASSIGN_OR_THROW(
+                 runtime::ComputationClient* absl_nonnull client,
+                 runtime::GetComputationClient());
+             client->SetCustomCompileOptions(options);
+           })
       .def(
           // from an XLA tensor to a PyCapsule.
           // When consuming the PyCapsule, we should synchronize
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index 79ff199eb2ff..364641a88692 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -446,6 +446,14 @@ class ComputationClient {
   // after the last ':' character of the device string.
   static int64_t GetDeviceOrdinal(const std::string& device);
 
+  // Sets XLA compile option overrides used by the backend compiler.
+  // - The map keys are XLA compiler flag names (env option override keys).
+  // - The values are stringified flag values.
+  // - Calling this method **overwrites** any previously set options.
+  //   (Pass an empty map to clear.)
+  virtual void SetCustomCompileOptions(
+      const std::unordered_map<std::string, std::string>& options) = 0;
+
 protected:
  static constexpr auto spmd_device_str = "SPMD:0";
 
diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h
index 8b45922c397f..6e7875457105 100644
--- a/torch_xla/csrc/runtime/ifrt_computation_client.h
+++ b/torch_xla/csrc/runtime/ifrt_computation_client.h
@@ -172,6 +172,11 @@ class IfrtComputationClient : public ComputationClient {
     XLA_ERROR() << __FUNCTION__ << " not implemented";
   }
 
+  void SetCustomCompileOptions(
+      const std::unordered_map<std::string, std::string>& options) override {
+    XLA_ERROR() << __FUNCTION__ << " not implemented";
+  }
+
   // Creates a new instance of IfrtComputationClient and initializes it.
   static absl::StatusOr<absl_nonnull std::unique_ptr<IfrtComputationClient>>
   Create();
 
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cpp b/torch_xla/csrc/runtime/pjrt_computation_client.cpp
index 280b50964d82..908af49cdb92 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cpp
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cpp
@@ -554,6 +554,9 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
 
   for (auto& instance : instances) {
     xla::CompileOptions compile_options;
+    for (const auto& [name, value] : custom_compile_options_) {
+      compile_options.env_option_overrides.push_back({name, value});
+    }
     if (enable_cm_in_mp) {
       compile_options.executable_build_options.set_use_spmd_partitioning(true);
       compile_options.env_option_overrides.push_back(
@@ -561,6 +564,7 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
       compile_options.env_option_overrides.push_back(
           {"xla_tpu_decompose_einsum_reduce_scatter", true});
     }
+
     if (instance.is_sharded) {
       // TODO(yeounoh) multi-host, multi-slice configurations
       compile_options.executable_build_options.set_use_spmd_partitioning(true);
@@ -1052,5 +1056,13 @@ void PjRtComputationClient::OnReadyCallback(
       [callback](absl::Status unused) { callback(); });
 }
 
+void PjRtComputationClient::SetCustomCompileOptions(
+    const std::unordered_map<std::string, std::string>& options) {
+  custom_compile_options_.clear();
+  for (const auto& [key, value] : options) {
+    custom_compile_options_[key] = value;
+  }
+}
+
 }  // namespace runtime
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index d550f1cce0cb..fc516a7042c9 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -165,6 +165,10 @@ class PjRtComputationClient : public ComputationClient {
   void OnReadyCallback(DataPtr data,
                        const std::function<void()>& callback) override;
 
+  // See base class for semantics. This call overwrites previously set options.
+  void SetCustomCompileOptions(
+      const std::unordered_map<std::string, std::string>& options) override;
+
   // Creates a new instance of PjRtComputationClient and initializes it.
   static absl::StatusOr<absl_nonnull std::unique_ptr<PjRtComputationClient>>
   Create();
@@ -197,6 +201,7 @@ class PjRtComputationClient : public ComputationClient {
   // If not nullptr, invoke this instead of the actual XLA compilation. Used
   // only for testing.
   std::function<absl::Status()> fake_xla_compile_ = nullptr;
+  std::unordered_map<std::string, std::string> custom_compile_options_;
 
   xla::PjRtDevice* StringToPjRtDevice(const std::string& device);
 
diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py
index 9062d6a9ef21..76f209f008bf 100644
--- a/torch_xla/torch_xla.py
+++ b/torch_xla/torch_xla.py
@@ -116,6 +116,7 @@ def compile(
     full_graph: Optional[bool] = False,
     name: Optional[str] = None,
     max_different_graphs: Optional[int] = None,
+    custom_compile_options: Optional[dict[str, Any]] = None,
 ):
   """
   Optimizes given model/function using torch_xla's LazyTensor tracing mode.
@@ -136,6 +137,11 @@ def compile(
     max_different_graphs (Optional[int]): number of different traced graphs of the
       given model/function that we are allowed to have. An error will be
       raised in case this limit is exceeded.
+    custom_compile_options (Optional[dict[str, Any]]): XLA compiler flag overrides.
+      Keys are XLA compiler flag names (forwarded to xla::CompileOptions.env_option_overrides),
+      and values may be bool, int, float, or str (internally stringified).
+      - {} (empty dict): clear previously set options.
+      - None (default): do not change previously set options (no-op).
 
   Example::
 
@@ -215,6 +221,8 @@ def _compile():
       torch_xla._XLAC._set_use_eager_mode(saved_eager_mode_status)
       torch_xla._XLAC._set_current_graph_name(saved_current_graph_name)
 
+  if custom_compile_options is not None:
+    torch_xla._XLAC._set_custom_compile_options(custom_compile_options)
   return _compile() if f is None else _compile()(f)
 
 
@@ -264,3 +272,14 @@ def launch(
     fn(xu.getenv_as(xenv.LOCAL_RANK, int), *args)
   else:
     xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
+
+
+def set_custom_compile_options(options: dict[str, Any]) -> None:
+  """Set XLA **compiler flag overrides** (env option overrides) for compilation.
+
+  Args:
+    options: Dict mapping XLA flag names to values. Values may be bool/float/int/str;
+      they will be stringified before being passed to XLA.
+      Pass an empty dict `{}` to clear previously set options.
+  """
+  torch_xla._XLAC._set_custom_compile_options(options)