Closed
Description
Bug description
Running with Lightning 2.0 and PyTorch 2.0.
Hi, I'm compiling my Lightning model with torch.compile. I do see a non-negligible (~20%) speedup in training, but TorchDynamo errors out in the validation loop when it encounters a self.logger.experiment.log call to Weights & Biases (see below).
Is there anything I can do to stop TorchDynamo from choking on this bit of code? Maybe a decorator that tells the compiler to skip the offending self.logger.experiment.log call? Many thanks!
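For reference, one possible shape of such a decorator is sketched below. It is not a confirmed fix for this error. It assumes the logging can be moved into its own method and wrapped with torch._dynamo.disable, which asks Dynamo to skip tracing the wrapped function. The LoggingMixin class, the log_image method, and the ImportError fallback (there only so the snippet runs without PyTorch installed) are hypothetical names for illustration, not Lightning or wandb APIs.

```python
try:
    import torch._dynamo as dynamo

    # torch._dynamo.disable wraps a function so that Dynamo skips tracing it
    dynamo_disable = dynamo.disable
except ImportError:
    # Assumption: fallback no-op so the sketch still runs without PyTorch
    def dynamo_disable(fn):
        return fn


class LoggingMixin:
    """Hypothetical helper class; `log_image` is not a Lightning API."""

    @dynamo_disable
    def log_image(self, tag, payload):
        # Executed eagerly: the logger call stays outside the compiled graph
        self.logger.experiment.log({tag: payload})
```

In the module from the traceback, the call at trainer.py line 214 would then become something like self.log_image(exp_log_tag, wandb.Image(save_path)). Whether this also avoids the matplotlib recompilation warnings is untested; the plotting code may need the same treatment.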
How to reproduce the bug
# torch.compile call
# type(model) == LightningModel
model = torch.compile(model, mode="default", backend="inductor", fullgraph=False)
# the model should log an image to wandb through a self.logger.experiment.log call
Error messages and logs
[2023-03-28 10:14:51,227] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
function: '<graph break in draw_text>' (/python3.9/site-packages/matplotlib/backends/backend_agg.py:212)
reasons: d == 2.0
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
[2023-03-28 10:14:51,298] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
function: '<graph break in draw_text>' (python3.9/site-packages/matplotlib/backends/backend_agg.py:212)
reasons: d == 2.0
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
[2023-03-28 10:14:51,621] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
function: '<graph break in draw_text>' (python3.9/site-packages/matplotlib/backends/backend_agg.py:212)
reasons: d == 2.0
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
[2023-03-28 10:14:51,968] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
function: '<graph break in draw_text>' (python3.9/site-packages/matplotlib/backends/backend_agg.py:212)
reasons: d == 2.0
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
[2023-03-28 10:14:52,445] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)
function: '<graph break in draw_text>' (python3.9/site-packages/matplotlib/backends/backend_agg.py:204)
reasons: s == 't850 target'
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
Traceback (most recent call last):
File "python3.9/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
out_code = transform_code_object(code, transform)
File "python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
transformations(instructions, code_options)
File "python3.9/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
tracer.run()
File "python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
super().run()
File "python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
and self.step()
File "python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
getattr(self, inst.opname)(inst)
File "python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1030, in LOAD_ATTR
result = BuiltinVariable(getattr).call_function(
File "python3.9/site-packages/torch/_dynamo/variables/builtin.py", line 566, in call_function
result = handler(tx, *args, **kwargs)
File "python3.9/site-packages/torch/_dynamo/variables/builtin.py", line 958, in call_getattr
obj.var_getattr(tx, name).clone(source=source).add_options(options)
File "python3.9/site-packages/torch/_dynamo/variables/user_defined.py", line 319, in var_getattr
return variables.UserMethodVariable(
File "python3.9/site-packages/torch/_dynamo/variables/functions.py", line 291, in call_function
return super().call_function(tx, args, kwargs)
File "python3.9/site-packages/torch/_dynamo/variables/functions.py", line 259, in call_function
return super().call_function(tx, args, kwargs)
File "python3.9/site-packages/torch/_dynamo/variables/functions.py", line 92, in call_function
return tx.inline_user_function_return(
File "/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 510, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1806, in inline_call
Traceback (most recent call last):
File "python3.9/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
return cls.inline_call_(parent, func, args, kwargs)
File "python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1862, in inline_call_
out_code = transform_code_object(code, transform)
File "python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
tracer.run()
File "python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
transformations(instructions, code_options)
File "python3.9/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
and self.step()
File "python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
tracer.run()
File "python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
getattr(self, inst.opname)(inst)
File "python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 1975, in LOAD_CLOSURE
super().run()
File "python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
self.push(self.closure_cells[inst.argval])
KeyError: fn
...
from user code:
File "trainer.py", line 214, in <graph break in _output_figure>
self.logger.experiment.log({exp_log_tag: wandb.Image(save_path)})
File "python3.9/site-packages/lightning_fabric/loggers/logger.py", line 114, in experiment
def get_experiment() -> Callable:
Set torch._dynamo.config.verbose=True for more information
Environment
Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow): LightningModule
#- PyTorch Lightning Version (e.g., 1.5.0): 2.0
#- Lightning App Version (e.g., 0.5.2): n/a
#- PyTorch Version (e.g., 2.0): 2.0
#- Python version (e.g., 3.9): 3.9
#- OS (e.g., Linux): Linux
#- CUDA/cuDNN version: 11.8
#- GPU models and configuration: A100
#- How you installed Lightning(`conda`, `pip`, source): conda
#- Running environment of LightningApp (e.g. local, cloud): n/a
More info
No response
cc @carmocca