|
1 | 1 | import functools |
| 2 | +import os |
2 | 3 | import threading |
| 4 | + |
3 | 5 | import torch_xla |
4 | 6 | import torch_xla.core.xla_model as xm |
5 | 7 |
|
@@ -183,3 +185,65 @@ def wrapper_trace_me(*args, **kwargs): |
183 | 185 | return wrapper_trace_me |
184 | 186 |
|
185 | 187 | return decorator_trace_me |
| 188 | + |
| 189 | + |
# The profiler implementation below is modeled on the JAX implementation:
# https://github.com/jax-ml/jax/blob/main/jax/_src/profiler.py
class _ProfileState:
  """Mutable bookkeeping for the profiler trace that is currently running.

  ``profile_session`` is ``None`` whenever no trace is running. ``lock`` is
  created once in ``__init__`` and is never replaced by :meth:`reset`.
  """

  def __init__(self):
    # The lock is created exactly once. reset() does not touch it, which lets
    # callers invoke reset() while they are still holding the lock.
    self.lock = threading.Lock()
    # Single source of truth for the idle values of every other attribute.
    self.reset()

  def reset(self):
    """Returns this instance to the idle (no trace running) state."""
    # Operate on ``self`` rather than on the module-level singleton so that the
    # method works for any instance and does not depend on a global that is
    # only bound after the class body has been executed.
    self.profile_session = None
    self.log_dir = None
    self.create_perfetto_link = False
    self.create_perfetto_trace = False


# Module-level instance used by start_trace() and stop_trace().
_profile_state = _ProfileState()
| 209 | + |
| 210 | + |
def start_trace(log_dir: os.PathLike | str) -> None:
  """Begins collecting a profiler trace.

  The trace captures CPU, GPU, and/or TPU activity, including Python functions
  and PyTorch/XLA on-device operations. Call :func:`stop_trace` to finish the
  trace and write the results to ``log_dir``.

  The saved trace can be inspected with TensorBoard. TensorBoard does not have
  to be running while the trace is being collected.

  Only a single trace can be collected at any time.

  Args:
    log_dir: Directory the profiler trace is saved to (usually the TensorBoard
      log directory).

  Raises:
    RuntimeError: If another trace is already running.
  """
  state = _profile_state
  with state.lock:
    already_running = state.profile_session is not None
    if already_running:
      raise RuntimeError("Profile has already been started. "
                         "Only one profile may be run at a time.")

    # Create the session first, then remember where stop_trace() should
    # export the collected data.
    session_factory = torch_xla._XLAC.profiler.TslProfilerSessionWrapper
    state.profile_session = session_factory()
    state.log_dir = str(log_dir)
| 236 | + |
| 237 | + |
def stop_trace() -> None:
  """Stops the currently-running profiler trace.

  The trace is saved to the ``log_dir`` passed to the corresponding
  :func:`start_trace` call. The profiler state is returned to idle even when
  stopping or exporting fails; the underlying exception is still propagated.

  Raises:
    RuntimeError: If a trace hasn't been started.
  """
  with _profile_state.lock:
    sess = _profile_state.profile_session
    if sess is None:
      raise RuntimeError("No profile started")
    try:
      sess.export(sess.stop(), str(_profile_state.log_dir))
    finally:
      # Always clear the session. Without this, a failure in stop()/export()
      # would leave the stale session registered and every later
      # start_trace() call would raise "Profile has already been started".
      _profile_state.reset()
0 commit comments