diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py
index 0e6ffd1dc..c7422d8d9 100644
--- a/captum/attr/_core/llm_attr.py
+++ b/captum/attr/_core/llm_attr.py
@@ -6,11 +6,22 @@
 from copy import copy
 from dataclasses import dataclass
 from textwrap import dedent, shorten
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
+
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TYPE_CHECKING,
+    Union,
+)
 
 import matplotlib.colors as mcolors
-import matplotlib.pyplot as plt
 
 import numpy as np
 import numpy.typing as npt
 import torch
@@ -34,6 +45,9 @@
     TextTemplateInput,
     TextTokenInput,
 )
+
+if TYPE_CHECKING:
+    from matplotlib.pyplot import Axes, Figure
 from torch import nn, Tensor
 
 DEFAULT_GEN_ARGS: Dict[str, Any] = {
@@ -145,7 +159,7 @@ def seq_attr_dict(self) -> Dict[str, float]:
 
     def plot_token_attr(
         self, show: bool = False
-    ) -> Union[None, Tuple[plt.Figure, plt.Axes]]:
+    ) -> Union[None, Tuple["Figure", "Axes"]]:
         """
         Generate a matplotlib plot for visualising the attribution
         of the output tokens.
@@ -167,6 +181,8 @@
         # always keep 0 as the mid point to differentiate pos/neg attr
         max_abs_attr_val = token_attr.abs().max().item()
 
+        import matplotlib.pyplot as plt
+
         fig, ax = plt.subplots()
 
         # Hide the grid
@@ -243,9 +259,7 @@
         else:
             return fig, ax
 
-    def plot_seq_attr(
-        self, show: bool = False
-    ) -> Union[None, Tuple[plt.Figure, plt.Axes]]:
+    def plot_seq_attr(self, show: bool = False) -> Union[None, Tuple["Figure", "Axes"]]:
         """
         Generate a matplotlib plot for visualising the attribution
         of the output sequence.
@@ -255,6 +269,8 @@
                 Default: False
         """
 
+        import matplotlib.pyplot as plt
+
         fig, ax = plt.subplots()
 
         data = self.seq_attr.cpu().numpy()
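
For reviewers: the pattern applied throughout this diff is a deferred matplotlib import. The module-level "import matplotlib.pyplot as plt" is removed, so importing llm_attr no longer requires matplotlib to be installed; the Figure/Axes names remain available to static type checkers through a TYPE_CHECKING guard, and the real import happens inside the plotting methods at call time. A minimal self-contained sketch of the same pattern follows; the AttrResult and plot names are illustrative placeholders, not captum's API.

    from typing import Optional, Tuple, TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by static type checkers (mypy/pyre), never at runtime,
        # so matplotlib is not needed just to import this module.
        from matplotlib.pyplot import Axes, Figure


    class AttrResult:
        def plot(self, show: bool = False) -> Optional[Tuple["Figure", "Axes"]]:
            # Deferred import: matplotlib is only required once a plotting
            # method is actually called.
            import matplotlib.pyplot as plt

            fig, ax = plt.subplots()
            ax.plot([0.1, 0.4, 0.2])
            if show:
                plt.show()
                return None
            return fig, ax

Note the string annotations ("Figure", "Axes") in the signature: because the names only exist under TYPE_CHECKING, they must not be evaluated at function-definition time, which is exactly why the diff rewrites plt.Figure/plt.Axes as quoted forward references.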