import inspect
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn

import tensorrt_llm


class DebuggerContext:
    """
    A context container that holds the debugger's running state, such as the
    module call tree, the log folder, the hook actions to run, and whether
    the current phase is pre-forward or after-forward.

    Arguments:
        dest_folder: str
            The working directory set on the debug context to control where
            the hooks dump data/info.
    """

    def __init__(self, dest_folder: Optional[str] = None):
        self.pre_forward_actions = []
        self.after_forward_actions = []

        self.layer_names = []
        self.layer_inner_counter = []

        self.module_forward_hook_handle = None
        self.module_forward_pre_hook_handle = None

        self.forward_hook_handles = {}  # module -> hook handle
        self.forward_pre_hook_handles = {}
        self.log_folder = dest_folder
        self.is_pre_forward = True
        self._init_log_folder()

    def _init_log_folder(self):
        if self.log_folder is None:
            pwd = os.getcwd()
            self.log_folder = os.path.join(pwd, "data_dump")

        # Each rank dumps into its own subfolder so that multi-rank runs do
        # not clobber each other's files.
        rank = tensorrt_llm.mpi_rank()

        p = Path(self.log_folder) / f"rank{rank}"
        self.log_folder = p.absolute()
        p.mkdir(parents=True, exist_ok=True)

    def get_log_folder(self):
        return self.log_folder

    def check_in_pre_forward(self):
        return self.is_pre_forward

    def mark_in_pre_forward(self, is_pre_forward):
        self.is_pre_forward = is_pre_forward

    def clear_state(self):
        self.pre_forward_actions.clear()
        self.after_forward_actions.clear()
        self.layer_names.clear()
        self.layer_inner_counter.clear()

        if self.module_forward_hook_handle is not None:
            self.module_forward_hook_handle.remove()
        if self.module_forward_pre_hook_handle is not None:
            self.module_forward_pre_hook_handle.remove()

        self.module_forward_hook_handle = None
        self.module_forward_pre_hook_handle = None

        for handler in self.forward_hook_handles.values():
            handler.remove()

        for handler in self.forward_pre_hook_handles.values():
            handler.remove()

        self.forward_hook_handles.clear()
        self.forward_pre_hook_handles.clear()

    def register_pre_forward_action(self, filter, action):
        self.pre_forward_actions.append((filter, action))

    def get_pre_forward_action(self):
        return self.pre_forward_actions

    def register_after_forward_action(self, filter, action):
        self.after_forward_actions.append((filter, action))

    def get_after_forward_action(self):
        return self.after_forward_actions

    def get_current_modules_tree(self):
        return self.layer_names

    def get_module_indices_tree(self):
        return self.layer_inner_counter

    def get_current_model_loop_index(self):
        if len(self.layer_inner_counter) >= 1:
            return self.layer_inner_counter[0] + 1
        return 0

    def do_actions(self, module, tensors, actions):
        assert isinstance(actions, list), "Actions shall be a list."
        for f, a in actions:
            # Each registered entry is a (Filter, action) pair; run the
            # action only on modules the filter accepts.
            if f.filter(module, self):
                a(module, tensors, self)


class Filter:

    def __init__(self):
        pass

    def filter(self, module: nn.Module, debug_ctx: DebuggerContext):
        raise NotImplementedError(
            "Subclasses need to implement the filter interface.")


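# Example (illustrative sketch): a custom filter plus an action registered on
# the current debug context once enable_debug has installed one.
# `LinearOnlyFilter` and `print_module` are hypothetical names used only for
# demonstration; any callable with the (module, tensors, debug_ctx)
# signature works as an action.
#
#     class LinearOnlyFilter(Filter):
#
#         def filter(self, module: nn.Module, debug_ctx: DebuggerContext):
#             return isinstance(module, nn.Linear)
#
#     def print_module(module, tensors, debug_ctx):
#         print(module.__class__.__name__)
#
#     get_current_debug_ctx().register_pre_forward_action(
#         LinearOnlyFilter(), print_module)
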
debug_ctx = None


def get_current_debug_ctx():
    global debug_ctx
    return debug_ctx


def set_current_debug_ctx(ctx):
    global debug_ctx
    debug_ctx = ctx


def pre_forward(module: nn.Module, args, kwargs):
    """
    This hook is registered on a module with module.register_forward_pre_hook
    and is executed before the module's forward is called.
    It records the module in the debug context's module tree and calls the
    debug context's do_actions function to execute all pre-forward actions
    registered on it for the current module.
    Args:
        module (nn.Module): the module this hook is executed on.
        args: the positional args of module.forward.
        kwargs (dict): the kwargs of module.forward.
    Returns:
        None
    """
    name = module.name if hasattr(module, "name") else module.__class__.__name__
    debug_ctx = get_current_debug_ctx()
    assert debug_ctx is not None, "DebugContext instance shall not be None."
    debug_ctx.mark_in_pre_forward(True)
    debug_ctx.get_current_modules_tree().append(name)
    if len(debug_ctx.get_module_indices_tree()) == 0:
        debug_ctx.get_module_indices_tree().append(0)

    if len(debug_ctx.get_current_modules_tree()) >= len(
            debug_ctx.get_module_indices_tree()):
        debug_ctx.get_module_indices_tree().append(0)

    # Bump the invocation counter at the current tree depth.
    depth = len(debug_ctx.get_current_modules_tree()) - 1
    debug_ctx.get_module_indices_tree()[depth] += 1
    debug_ctx.do_actions(module, args, debug_ctx.get_pre_forward_action())
    return None


def after_forward(module: nn.Module, args, kwargs, output):
    """
    This hook is registered on a module with module.register_forward_hook
    and is executed after the module's forward is called.
    It removes the module from the debug context's module tree and calls the
    debug context's do_actions function to execute all after-forward actions
    registered on it for the current module.
    Args:
        module (nn.Module): the module this hook is executed on.
        args: the positional args of module.forward.
        kwargs (dict): the kwargs of module.forward.
        output: the values (tensors) returned from module.forward().
    Returns:
        None
    """
    debug_ctx = get_current_debug_ctx()
    assert debug_ctx is not None, "DebugContext instance shall not be None."
    debug_ctx.mark_in_pre_forward(False)
    debug_ctx.do_actions(module, [args, output],
                         debug_ctx.get_after_forward_action())
    name = module.name if hasattr(module, "name") else module.__class__.__name__
    old_name = debug_ctx.get_current_modules_tree().pop(-1)
    assert name == old_name, "module mismatch"

    debug_ctx.get_module_indices_tree().pop(-1)
    return None


def enable_debug(model: nn.Module,
                 dest_folder: Optional[str] = None,
                 filter: Optional[Filter] = None):
    """
    The function-style interface to enable the debugger on a model.
    If a filter is provided, it is used to select which modules get hooks
    registered. If no filter is provided, hooks are registered on all modules.
    Example:
        from tensorrt_llm._torch.debug.debug_hook import (disable_debug,
                                                          enable_debug)
        model_config = ModelConfig(pretrained_config=llama_config,
                                   attn_backend=backend)
        llama = LlamaForCausalLM(model_config).to(dtype).to(device)
        llama.load_weights(hf_llama.state_dict())
        with torch.inference_mode():
            enable_debug(llama, r"tensor_dump")
            attn_metadata.prepare()
            logits = llama.forward(input_ids=input_ids,
                                   position_ids=position_ids,
                                   attn_metadata=attn_metadata)
            disable_debug()

    Note: this method requires the user to disable the debugger afterwards by
    calling disable_debug.
    Args:
        model (nn.Module): the model to enable debug hooks on.
        dest_folder: the working directory set on the debug context to control
            where the hooks dump data/info.
        filter: a filter deciding which modules get debug hooks registered.
    Returns:
        None
    """
    debug_ctx = get_current_debug_ctx()
    assert debug_ctx is None, "DebugContext shall be None when enabling the debugger context."
    debug_ctx = DebuggerContext(dest_folder)
    set_current_debug_ctx(debug_ctx)

    debug_ctx.get_current_modules_tree().clear()
    debug_ctx.get_module_indices_tree().clear()
    for name, submodule in model.named_modules():
        if name == "":
            continue

        do_hook = filter.filter(
            submodule, debug_ctx) if filter is not None else True
        if not do_hook:
            continue

        if submodule not in debug_ctx.forward_hook_handles:
            debug_ctx.forward_hook_handles[
                submodule] = submodule.register_forward_hook(
                    after_forward, with_kwargs=True, always_call=True)

        if submodule not in debug_ctx.forward_pre_hook_handles:
            debug_ctx.forward_pre_hook_handles[
                submodule] = submodule.register_forward_pre_hook(
                    pre_forward, with_kwargs=True)


def disable_debug():
    """
    The function-style interface to disable the debugger on a model.
    """
    debug_ctx = get_current_debug_ctx()
    assert debug_ctx is not None, "DebugContext shall not be None when disabling the debugger context."
    # clear_state() already removes and clears all registered hook handles.
    debug_ctx.clear_state()
    set_current_debug_ctx(None)


@contextmanager
def debug_mode(model: nn.Module,
               dest_folder: Optional[str] = None,
               filter: Optional[Filter] = None):
    """
    The context-manager style interface to enable the debugger on a model.
    If a filter is provided, it is used to select which modules get hooks
    registered. If no filter is provided, hooks are registered on all modules.
    Example:
        from tensorrt_llm._torch.debug.debug_hook import debug_mode
        model_config = ModelConfig(pretrained_config=llama_config,
                                   attn_backend=backend)
        llama = LlamaForCausalLM(model_config).to(dtype).to(device)
        llama.load_weights(hf_llama.state_dict())
        with torch.inference_mode(), debug_mode(llama, r"tensor_dump"):
            attn_metadata.prepare()
            logits = llama.forward(input_ids=input_ids,
                                   position_ids=position_ids,
                                   attn_metadata=attn_metadata)
    Args:
        model (nn.Module): the model to enable debug hooks on.
        dest_folder: the working directory set on the debug context to control
            where the hooks dump data/info.
        filter: a filter deciding which modules get debug hooks registered.
    Yields:
        The model with debug hooks enabled.
    """
    try:
        enable_debug(model, dest_folder, filter)
        register_tensor_dump_hook()
        yield model
    finally:
        disable_debug()


def get_forward_arg_names(module: nn.Module):
    """Return the declared argument names of module.forward (including
    'self' for bound methods), or None if the module has no forward."""
    if hasattr(module, "forward"):
        forward_func = module.forward
        args = inspect.getfullargspec(forward_func).args
        return args

    return None

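# For example (illustrative): get_forward_arg_names(nn.Linear(2, 2)) returns
# ['self', 'input'], since nn.Linear.forward is declared as
# forward(self, input); dump_tensor below drops the leading 'self'.
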

class DumpTensorFilter(Filter):
    """
    Below is an example filter for the tensor-dump hook.
    Normally, to implement a hook you implement a filter, by inheriting from
    the base class Filter, plus a function that defines what to do, such as
    dumping data, modifying data, injecting actions, etc.
    """

    def __init__(self):
        pass

    def filter(self, module: nn.Module, debug_ctx: DebuggerContext):
        return True


def dump_tensor(module: nn.Module, data_tensor, debug_ctx: DebuggerContext):
    tensor_counter = 0

    # Drop the leading 'self' so positional tensors can be matched to the
    # declared argument names of module.forward.
    input_tensor_names = get_forward_arg_names(module)
    if input_tensor_names is not None:
        input_tensor_names = input_tensor_names[1:]

    def get_dump_file_path(tensor):
        nonlocal tensor_counter
        nonlocal input_tensor_names
        assert debug_ctx.get_log_folder(
        ) is not None, "Log folder shall be initialized by DebugContext."

        # Encode the module call path as "<loop_idx>.<layer_name>" segments
        # joined by "-".
        name_parts = []
        for idx in range(len(debug_ctx.get_current_modules_tree())):
            inner_idx = f"{debug_ctx.get_module_indices_tree()[idx]}"
            layer_name = debug_ctx.get_current_modules_tree()[idx]
            name_parts.append(".".join([inner_idx, layer_name]))
        module_path = "-".join(name_parts)

        tensor_type = "input" if debug_ctx.check_in_pre_forward() else "output"
        if hasattr(tensor, "name") and tensor.name is not None:
            tensor_name = f"{tensor_type}.{tensor.name}.pt"
        elif input_tensor_names is not None and tensor_counter < len(
                input_tensor_names):
            tensor_name = f"{tensor_type}.{input_tensor_names[tensor_counter]}.pt"
        else:
            tensor_name = f"{tensor_type}.{tensor_counter}.pt"

        tensor_counter += 1
        module_path = "-".join([module_path, tensor_name])
        p = Path(debug_ctx.get_log_folder()) / module_path
        return p.absolute()

    def dump_tensor_data(t):
        file_path = get_dump_file_path(t)
        torch.save(t, file_path)

    def dump(t):
        # Recursively dump tensors found in nested tuples/lists.
        if isinstance(t, torch.Tensor):
            dump_tensor_data(t)
        elif isinstance(t, (tuple, list)):
            for _t in t:
                dump(_t)

    dump(data_tensor)

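# Illustrative resulting path (hypothetical module tree and rank):
#
#     <dest_folder>/rank0/1.LlamaForCausalLM-1.LlamaModel-3.LlamaDecoderLayer-input.position_ids.pt
#
# i.e. one ".pt" file per tensor, named by the module call path, the
# input/output phase, and the argument name (or a running counter).
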

def register_tensor_dump_hook():
    debug_ctx = get_current_debug_ctx()
    assert debug_ctx is not None, "Enable the debugger before registering the tensor dump hook."
    debug_ctx.register_pre_forward_action(DumpTensorFilter(), dump_tensor)
    debug_ctx.register_after_forward_action(DumpTensorFilter(), dump_tensor)

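
# Usage sketch (illustrative): end-to-end tensor dumping on a toy model.
# `TinyMLP` is a hypothetical module defined only for this demo; running the
# file directly writes one ".pt" file per hooked tensor under
# ./tensor_dump/rank<N>/.
if __name__ == "__main__":

    class TinyMLP(nn.Module):

        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(8, 16)
            self.fc2 = nn.Linear(16, 8)

        def forward(self, x):
            return self.fc2(torch.relu(self.fc1(x)))

    model = TinyMLP()
    # debug_mode registers the dump hooks, runs the forward pass with them
    # enabled, and removes them on exit.
    with torch.inference_mode(), debug_mode(model, "tensor_dump"):
        model(torch.randn(2, 8))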