Skip to content

Commit faac087

Browse files
ethanbwaitemeta-codesync[bot]
authored andcommitted
Lazy import matplotlib.pyplot to prevent lockfile errors (#1653)
Summary: Pull Request resolved: #1653 Lazy import matplotlib to avoid lockfile issues with some of the font files when importing captum library in multithreaded environments eg: ``` [2025-10-01T04:02:13.219-07:00] TimeoutError: Lock error: Matplotlib failed to acquire the following lock file: [2025-10-01T04:02:13.219-07:00] /var/twsvcscm/.cache/matplotlib/fontlist-v330.json.matplotlib-lock [2025-10-01T04:02:13.219-07:00] This maybe due to another process holding this lock file. If you are sure no [2025-10-01T04:02:13.219-07:00] other Matplotlib process is running, remove this file and try again. ``` Reviewed By: craymichael Differential Revision: D83707766 fbshipit-source-id: 8a5b9a537946f5856c3cb8c9a2891b9a57daf0a5
1 parent 7dc0506 commit faac087

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

captum/attr/_core/llm_attr.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,22 @@
66
from copy import copy
77
from dataclasses import dataclass
88
from textwrap import dedent, shorten
9-
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
9+
10+
from typing import (
11+
Any,
12+
Callable,
13+
cast,
14+
Dict,
15+
List,
16+
Optional,
17+
Tuple,
18+
Type,
19+
TYPE_CHECKING,
20+
Union,
21+
)
1022

1123
import matplotlib.colors as mcolors
1224

13-
import matplotlib.pyplot as plt
1425
import numpy as np
1526
import numpy.typing as npt
1627

@@ -34,6 +45,9 @@
3445
TextTemplateInput,
3546
TextTokenInput,
3647
)
48+
49+
if TYPE_CHECKING:
50+
from matplotlib.pyplot import Axes, Figure
3751
from torch import nn, Tensor
3852

3953
DEFAULT_GEN_ARGS: Dict[str, Any] = {
@@ -145,7 +159,7 @@ def seq_attr_dict(self) -> Dict[str, float]:
145159

146160
def plot_token_attr(
147161
self, show: bool = False
148-
) -> Union[None, Tuple[plt.Figure, plt.Axes]]:
162+
) -> Union[None, Tuple["Figure", "Axes"]]:
149163
"""
150164
Generate a matplotlib plot for visualising the attribution
151165
of the output tokens.
@@ -167,6 +181,8 @@ def plot_token_attr(
167181
# always keep 0 as the mid point to differentiate pos/neg attr
168182
max_abs_attr_val = token_attr.abs().max().item()
169183

184+
import matplotlib.pyplot as plt
185+
170186
fig, ax = plt.subplots()
171187

172188
# Hide the grid
@@ -243,9 +259,7 @@ def plot_token_attr(
243259
else:
244260
return fig, ax
245261

246-
def plot_seq_attr(
247-
self, show: bool = False
248-
) -> Union[None, Tuple[plt.Figure, plt.Axes]]:
262+
def plot_seq_attr(self, show: bool = False) -> Union[None, Tuple["Figure", "Axes"]]:
249263
"""
250264
Generate a matplotlib plot for visualising the attribution
251265
of the output sequence.
@@ -255,6 +269,8 @@ def plot_seq_attr(
255269
Default: False
256270
"""
257271

272+
import matplotlib.pyplot as plt
273+
258274
fig, ax = plt.subplots()
259275

260276
data = self.seq_attr.cpu().numpy()

0 commit comments

Comments
 (0)