Skip to content

Commit 4d9b76a

Browse files
authored
Fix RWKV backward on GPU (#23774)
1 parent 8d28dba commit 4d9b76a

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

src/transformers/models/rwkv/modeling_rwkv.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def forward(ctx, time_decay, time_first, key, value, state=None, return_state=Fa
159159

160160
@staticmethod
161161
# g stands for grad
162-
def backward(ctx, g_output):
162+
def backward(ctx, g_output, g_state=None):
163163
input_dtype = ctx.input_dtype
164164

165165
time_decay, time_first, key, value, output = ctx.saved_tensors
@@ -188,17 +188,14 @@ def backward(ctx, g_output):
188188
g_key,
189189
g_value,
190190
)
191-
g_time_decay = torch.sum(g_time_decay, dim=0)
192-
g_time_first = torch.sum(g_time_first, dim=0)
193191

194192
return (
195-
None,
196-
None,
197-
None,
198193
g_time_decay.to(input_dtype),
199194
g_time_first.to(input_dtype),
200195
g_key.to(input_dtype),
201196
g_value.to(input_dtype),
197+
None,
198+
None,
202199
)
203200

204201

0 commit comments

Comments
 (0)