Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions captum/attr/_core/llm_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,22 @@
from copy import copy
from dataclasses import dataclass
from textwrap import dedent, shorten
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union

from typing import (
Any,
Callable,
cast,
Dict,
List,
Optional,
Tuple,
Type,
TYPE_CHECKING,
Union,
)

import matplotlib.colors as mcolors

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt

Expand All @@ -34,6 +45,9 @@
TextTemplateInput,
TextTokenInput,
)

if TYPE_CHECKING:
from matplotlib.pyplot import Axes, Figure
from torch import nn, Tensor

DEFAULT_GEN_ARGS: Dict[str, Any] = {
Expand Down Expand Up @@ -145,7 +159,7 @@ def seq_attr_dict(self) -> Dict[str, float]:

def plot_token_attr(
self, show: bool = False
) -> Union[None, Tuple[plt.Figure, plt.Axes]]:
) -> Union[None, Tuple["Figure", "Axes"]]:
"""
Generate a matplotlib plot for visualising the attribution
of the output tokens.
Expand All @@ -167,6 +181,8 @@ def plot_token_attr(
# always keep 0 as the mid point to differentiate pos/neg attr
max_abs_attr_val = token_attr.abs().max().item()

import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# Hide the grid
Expand Down Expand Up @@ -243,9 +259,7 @@ def plot_token_attr(
else:
return fig, ax

def plot_seq_attr(
self, show: bool = False
) -> Union[None, Tuple[plt.Figure, plt.Axes]]:
def plot_seq_attr(self, show: bool = False) -> Union[None, Tuple["Figure", "Axes"]]:
"""
Generate a matplotlib plot for visualising the attribution
of the output sequence.
Expand All @@ -255,6 +269,8 @@ def plot_seq_attr(
Default: False
"""

import matplotlib.pyplot as plt

fig, ax = plt.subplots()

data = self.seq_attr.cpu().numpy()
Expand Down
Loading