
Commit 57b532b

HuiGao-NV authored and dominicshanshan committed

Add debug hook to support dump tensor data and add new debug functions easily (NVIDIA#5182)

Signed-off-by: Hui Gao

1 parent 8d70c9d commit 57b532b

File tree: 3 files changed, +553 -0 lines changed


tensorrt_llm/_torch/debug/__init__.py

Whitespace-only changes.
Lines changed: 373 additions & 0 deletions
@@ -0,0 +1,373 @@
import inspect
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn

import tensorrt_llm

class DebuggerContext:
    """
    A context container that holds the debugger's running state, such as the
    module tree, the log folder, the registered hooks, and whether execution
    is currently in the pre-forward or post-forward phase.

    Arguments:
        dest_folder: str
            The working directory where the hooks dump data/info.
    """

    def __init__(self, dest_folder: str = None):
        self.pre_forward_actions = []
        self.after_forward_actions = []

        self.layer_names = []
        self.layer_inner_counter = []

        self.module_forward_hook_handle = None
        self.module_forward_pre_hook_handle = None

        self.forward_hook_handles = {}  # module to hook handle
        self.forward_pre_hook_handles = {}
        self.log_folder = dest_folder
        self.is_pre_forward = True
        self._init_log_folder()

    def _init_log_folder(self):
        if self.log_folder is None:
            pwd = os.getcwd()
            self.log_folder = os.path.join(pwd, "data_dump")

        rank = tensorrt_llm.mpi_rank()

        p = Path(self.log_folder) / f"rank{rank}"
        self.log_folder = p.absolute()
        p.mkdir(parents=True, exist_ok=True)

    def get_log_folder(self):
        return self.log_folder

    def check_in_pre_forward(self):
        return self.is_pre_forward

    def mark_in_pre_forward(self, is_pre_forward):
        self.is_pre_forward = is_pre_forward

    def clear_state(self):
        self.pre_forward_actions.clear()
        self.after_forward_actions.clear()
        self.layer_names.clear()
        self.layer_inner_counter.clear()

        if self.module_forward_hook_handle is not None:
            self.module_forward_hook_handle.remove()
        if self.module_forward_pre_hook_handle is not None:
            self.module_forward_pre_hook_handle.remove()

        self.module_forward_hook_handle = None
        self.module_forward_pre_hook_handle = None

        for _, handler in self.forward_hook_handles.items():
            handler.remove()

        for _, handler in self.forward_pre_hook_handles.items():
            handler.remove()

        self.forward_hook_handles.clear()
        self.forward_pre_hook_handles.clear()

    def register_pre_forward_action(self, filter, action):
        self.pre_forward_actions.append((filter, action))

    def get_pre_forward_action(self):
        return self.pre_forward_actions

    def register_after_forward_action(self, filter, action):
        self.after_forward_actions.append((filter, action))

    def get_after_forward_action(self):
        return self.after_forward_actions

    def get_current_modules_tree(self):
        return self.layer_names

    def get_module_indices_tree(self):
        return self.layer_inner_counter

    def get_current_model_loop_index(self):
        return self.layer_inner_counter[0] + 1 if len(
            self.layer_inner_counter) >= 1 else 0

    def do_actions(self, module, tensors, actions):
        assert isinstance(actions, list), "Actions shall be a list."
        for k, a in actions:
            if k.filter(module, tensors):
                a(module, tensors, self)

class Filter:

    def __init__(self):
        pass

    def filter(self, module: nn.Module, debug_ctx: DebuggerContext):
        raise NotImplementedError("Need to implement filter interface.")

debug_ctx = None


def get_current_debug_ctx():
    global debug_ctx
    return debug_ctx


def set_current_debug_ctx(ctx):
    global debug_ctx
    debug_ctx = ctx

def pre_forward(module: nn.Module, args, kwargs):
    """
    This hook is registered on a module with module.register_forward_pre_hook,
    so it runs before the module's forward is called.
    It records the module in the debug context's module tree and calls the
    context's do_actions to execute all pre-forward hooks registered for the
    current module.
    Args:
        module (nn.Module): the module this hook is executed on.
        args: the positional args of module.forward.
        kwargs (dict): the kwargs of module.forward.
    Returns:
        None
    """
    name = module.name if hasattr(module, "name") else module.__class__.__name__
    debug_ctx = get_current_debug_ctx()
    assert debug_ctx is not None, "DebugContext instance shall not be None."
    debug_ctx.mark_in_pre_forward(True)
    debug_ctx.get_current_modules_tree().append(name)
    if len(debug_ctx.get_module_indices_tree()) == 0:
        debug_ctx.get_module_indices_tree().append(0)

    if len(debug_ctx.get_current_modules_tree()) >= len(
            debug_ctx.get_module_indices_tree()):
        debug_ctx.get_module_indices_tree().append(0)

    # Increment the call counter for the current depth of the module tree.
    debug_ctx.get_module_indices_tree()[
        len(debug_ctx.get_current_modules_tree()) -
        1] = debug_ctx.get_module_indices_tree()[
            len(debug_ctx.get_current_modules_tree()) - 1] + 1
    debug_ctx.do_actions(module, args, debug_ctx.get_pre_forward_action())
    return None

def after_forward(module: nn.Module, args, kwargs, output):
    """
    This hook is registered on a module with module.register_forward_hook,
    so it runs after the module's forward is called.
    It pops the module from the debug context's module tree and calls the
    context's do_actions to execute all after-forward hooks registered for the
    current module.
    Args:
        module (nn.Module): the module this hook is executed on.
        args: the positional args of module.forward.
        kwargs (dict): the kwargs of module.forward.
        output: the values (tensors) returned from module.forward().
    Returns:
        None
    """
    debug_ctx = get_current_debug_ctx()
    debug_ctx.mark_in_pre_forward(False)
    debug_ctx.do_actions(module, [args, output],
                         debug_ctx.get_after_forward_action())
    name = module.name if hasattr(module, "name") else module.__class__.__name__
    old_name = debug_ctx.get_current_modules_tree().pop(-1)
    assert name == old_name, "module mismatch"

    debug_ctx.get_module_indices_tree().pop(-1)
    return None

def enable_debug(model: nn.Module,
                 dest_folder: Optional[str] = None,
                 filter: Optional[Filter] = None):
    """
    The function-style interface to enable the debugger on a model.
    If a filter is provided, it is used to select which modules get the debug
    hooks registered. If no filter is provided, hooks are registered on all
    modules.
    Example:
        from tensorrt_llm._torch.debug.debug_hook import enable_debug
        model_config = ModelConfig(pretrained_config=llama_config,
                                   attn_backend=backend)
        llama = LlamaForCausalLM(model_config).to(dtype).to(device)
        llama.load_weights(hf_llama.state_dict())
        with torch.inference_mode():
            enable_debug(llama, r"tensor_dump")
            attn_metadata.prepare()
            logits = llama.forward(input_ids=input_ids,
                                   position_ids=position_ids,
                                   attn_metadata=attn_metadata)

    Note: with this method the user must disable debugging afterwards by calling disable_debug().
    Args:
        model (nn.Module): the model to enable the debug hooks on.
        dest_folder: the working directory set on the debug context, i.e. where the hooks dump data/info.
        filter: a filter that decides which modules get the debug hooks registered.
    Returns:
        None
    """
    debug_ctx = get_current_debug_ctx()
    assert debug_ctx is None, "DebugContext shall be None when enabling the debugger context."
    debug_ctx = DebuggerContext(dest_folder)
    set_current_debug_ctx(debug_ctx)

    debug_ctx.get_current_modules_tree().clear()
    debug_ctx.get_module_indices_tree().clear()
    for name, submodule in model.named_modules():
        if name == "":
            continue

        if submodule not in debug_ctx.forward_hook_handles:
            do_hook = filter(submodule) if filter is not None else True
            if do_hook:
                debug_ctx.forward_hook_handles[
                    submodule] = submodule.register_forward_hook(
                        after_forward, with_kwargs=True, always_call=True)

        if submodule not in debug_ctx.forward_pre_hook_handles:
            do_hook = filter(submodule) if filter is not None else True
            if do_hook:
                debug_ctx.forward_pre_hook_handles[
                    submodule] = submodule.register_forward_pre_hook(
                        pre_forward, with_kwargs=True)

def disable_debug():
    """
    The function-style interface to disable the debugger on the model.
    """
    debug_ctx = get_current_debug_ctx()
    assert debug_ctx is not None, "DebugContext shall not be None when disabling the debugger context."
    debug_ctx.clear_state()
    for _, handler in debug_ctx.forward_hook_handles.items():
        handler.remove()

    for _, handler in debug_ctx.forward_pre_hook_handles.items():
        handler.remove()

    debug_ctx.forward_hook_handles.clear()
    debug_ctx.forward_pre_hook_handles.clear()
    set_current_debug_ctx(None)

@contextmanager
def debug_mode(model: nn.Module,
               dest_folder: Optional[str] = None,
               filter: Optional[Filter] = None):
    """
    The context-manager style interface to enable the debugger on a model.
    If a filter is provided, it is used to select which modules get the debug
    hooks registered. If no filter is provided, hooks are registered on all
    modules.
    Example:
        from tensorrt_llm._torch.debug.debug_hook import debug_mode
        model_config = ModelConfig(pretrained_config=llama_config,
                                   attn_backend=backend)
        llama = LlamaForCausalLM(model_config).to(dtype).to(device)
        llama.load_weights(hf_llama.state_dict())
        with torch.inference_mode(), debug_mode(llama, r"tensor_dump"):
            attn_metadata.prepare()
            logits = llama.forward(input_ids=input_ids,
                                   position_ids=position_ids,
                                   attn_metadata=attn_metadata)
    Args:
        model (nn.Module): the model to enable the debug hooks on.
        dest_folder: the working directory set on the debug context, i.e. where the hooks dump data/info.
        filter: a filter that decides which modules get the debug hooks registered.
    Yields:
        the model, with debug hooks enabled.
    """
    try:
        enable_debug(model, dest_folder, filter)
        register_tensor_dump_hook()
        yield model
    finally:
        disable_debug()

def get_forward_arg_names(module: nn.Module):
    if hasattr(module, "forward"):
        forward_func = module.forward
        args = inspect.getfullargspec(forward_func).args
        return args

    return None

class DumpTensorFilter(Filter):
    """
    Below is one hook for dumping tensors.
    Normally, to implement a hook you need to implement a filter (by
    inheriting from the base class Filter) and a function that defines what
    to do, such as dumping data, modifying data, injecting actions, etc.
    """

    def __init__(self):
        pass

    def filter(self, module: nn.Module, debug_ctx: DebuggerContext):
        return True

def dump_tensor(module: nn.Module, data_tensor, debug_ctx: DebuggerContext):
    tensor_counter = 0

    input_tensor_names = get_forward_arg_names(module)
    if input_tensor_names is not None:
        input_tensor_names = input_tensor_names[1:]  # drop "self"

    def get_dump_file_path(tensor):
        nonlocal tensor_counter
        nonlocal input_tensor_names
        assert debug_ctx.get_log_folder(
        ) is not None, "Log folder shall be initialized by DebugContext."

        # Build the module path by joining "<call index>.<module name>" for
        # each level of the current module tree.
        name_parts = []
        for idx in range(len(debug_ctx.get_current_modules_tree())):
            inner_idx = f"{debug_ctx.get_module_indices_tree()[idx]}"
            layer_name = debug_ctx.get_current_modules_tree()[idx]
            name_parts.append(".".join([inner_idx, layer_name]))
        module_path = "-".join(name_parts)

        tensor_type = "input" if debug_ctx.check_in_pre_forward() else "output"
        if hasattr(tensor, "name") and tensor.name is not None:
            tensor_name = f"{tensor_type}.{tensor.name}.pt"
        elif tensor_counter < len(input_tensor_names):
            tensor_name = f"{tensor_type}.{input_tensor_names[tensor_counter]}.pt"
        else:
            tensor_name = f"{tensor_type}.{tensor_counter}.pt"

        tensor_counter += 1
        module_path = "-".join([module_path, tensor_name])
        p = Path(debug_ctx.get_log_folder()) / module_path
        return p.absolute()

    def dump_tensor_data(t):
        file_path = get_dump_file_path(t)
        torch.save(t, file_path)

    def dump(t):
        # Recursively dump any tensors found in nested tuples/lists.
        if isinstance(t, torch.Tensor):
            dump_tensor_data(t)
        elif isinstance(t, (tuple, list)):
            for _t in t:
                dump(_t)

    dump(data_tensor)

def register_tensor_dump_hook():
    debug_ctx = get_current_debug_ctx()
    assert debug_ctx is not None, "DebugContext shall not be None."
    debug_ctx.register_pre_forward_action(DumpTensorFilter(), dump_tensor)
    debug_ctx.register_after_forward_action(DumpTensorFilter(), dump_tensor)
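Usage sketch (not part of the commit): assuming this file is importable as tensorrt_llm._torch.debug.debug_hook, as the docstring examples suggest, wrapping a forward pass in debug_mode registers the tensor-dump hook on every submodule, writes each submodule's input and output tensors as .pt files under <dest_folder>/rank<mpi_rank>/, and removes the hooks on exit. The TinyModel below is a made-up stand-in for a real model.

    import torch
    import torch.nn as nn

    from tensorrt_llm._torch.debug.debug_hook import debug_mode


    class TinyModel(nn.Module):
        # Hypothetical toy model used only for this sketch.

        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 16)

        def forward(self, hidden_states):
            return self.proj(hidden_states)


    model = TinyModel()
    x = torch.randn(2, 16)

    # Hooks are added on entry and removed on exit; each submodule's input and
    # output tensors are written with torch.save and can be reloaded offline
    # with torch.load for comparison against a reference run.
    with torch.inference_mode(), debug_mode(model, "tensor_dump"):
        model(x)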

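The DumpTensorFilter docstring describes the extension pattern for new debug functions: subclass Filter to decide which modules an action applies to, write an action callable, and register the pair on the active DebuggerContext. A minimal sketch under that assumption follows; LinearOnlyFilter, collect_shapes, and log_shapes are illustrative names, not part of the commit.

    import torch
    import torch.nn as nn

    from tensorrt_llm._torch.debug.debug_hook import (Filter, disable_debug,
                                                      enable_debug,
                                                      get_current_debug_ctx)


    class LinearOnlyFilter(Filter):
        # do_actions passes the forward tensors as the second argument.

        def filter(self, module, tensors):
            return isinstance(module, nn.Linear)


    def collect_shapes(obj):
        # Recursively collect tensor shapes from nested tuples/lists.
        if isinstance(obj, torch.Tensor):
            return [tuple(obj.shape)]
        if isinstance(obj, (tuple, list)):
            shapes = []
            for item in obj:
                shapes.extend(collect_shapes(item))
            return shapes
        return []


    def log_shapes(module, tensors, debug_ctx):
        # In the pre-forward phase `tensors` is the forward args tuple;
        # in the post-forward phase it is [args, output].
        phase = "pre" if debug_ctx.check_in_pre_forward() else "post"
        print(f"[{phase}] {module.__class__.__name__}: {collect_shapes(tensors)}")


    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))

    enable_debug(model)
    get_current_debug_ctx().register_pre_forward_action(LinearOnlyFilter(), log_shapes)
    get_current_debug_ctx().register_after_forward_action(LinearOnlyFilter(), log_shapes)

    model(torch.randn(2, 8))

    disable_debug()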