Skip to content

Commit 7de8a3a

Browse files
merrymercytqchen
authored andcommitted
[TOPI] Memoize winograd matrix (#3687)
* [TOPI] Memoize winograd matrix * lint * Fix name
1 parent 33ab3c6 commit 7de8a3a

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

python/tvm/contrib/pickle_memoize.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@ class Cache(object):
3434
----------
3535
key: str
3636
The file key to the function
37+
save_at_exit: bool
38+
Whether save the cache to file when the program exits
3739
"""
3840
cache_by_key = {}
39-
def __init__(self, key):
41+
def __init__(self, key, save_at_exit):
4042
cache_dir = ".pkl_memoize_py{0}".format(sys.version_info[0])
4143
if not os.path.exists(cache_dir):
4244
os.mkdir(cache_dir)
@@ -49,6 +51,7 @@ def __init__(self, key):
4951
else:
5052
self.cache = {}
5153
self.dirty = False
54+
self.save_at_exit = save_at_exit
5255

5356
def save(self):
5457
if self.dirty:
@@ -60,16 +63,19 @@ def save(self):
6063
def _atexit():
6164
"""Save handler."""
6265
for value in Cache.cache_by_key.values():
63-
value.save()
66+
if value.save_at_exit:
67+
value.save()
6468

6569

66-
def memoize(key):
70+
def memoize(key, save_at_exit=False):
6771
"""Memoize the result of function and reuse multiple times.
6872
6973
Parameters
7074
----------
7175
key: str
7276
The unique key to the file
77+
save_at_exit: bool
78+
Whether save the cache to file when the program exits
7379
7480
Returns
7581
-------
@@ -81,9 +87,9 @@ def _register(f):
8187
allow_types = (string_types, int, float)
8288
fkey = key + "." + f.__name__ + ".pkl"
8389
if fkey not in Cache.cache_by_key:
84-
Cache.cache_by_key[fkey] = Cache(fkey)
90+
Cache.cache_by_key[fkey] = Cache(fkey, save_at_exit)
8591
cache = Cache.cache_by_key[fkey]
86-
cargs = tuple(x.cell_contents for x in f.__closure__)
92+
cargs = tuple(x.cell_contents for x in f.__closure__) if f.__closure__ else ()
8793
cargs = (len(cargs),) + cargs
8894

8995
def _memoized_f(func, *args, **kwargs):

topi/python/topi/nn/winograd_util.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from operator import mul
2626
from functools import reduce
2727
import numpy as np
28+
from tvm.contrib.pickle_memoize import memoize
2829
from ..util import const_matrix
2930

3031

@@ -131,6 +132,8 @@ def _interpolation_points(degree):
131132

132133
return np.array(in_pts[degree-1], dtype=np.float64)
133134

135+
136+
@memoize("topi.nn.winograd_matrices", save_at_exit=False)
134137
def winograd_transform_matrices(tile_size, kernel_size, out_dtype):
135138
"""Compute the A, B, and G transform matrices for `tile_size` as a `tvm.Expr`.
136139
"""

0 commit comments

Comments
 (0)