@@ -25,6 +25,7 @@
 import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import dot_product_attention_weights
 from flax.traverse_util import flatten_dict, unflatten_dict
 from jax.random import PRNGKey
@@ -53,6 +54,8 @@
 _CONFIG_FOR_DOC = "LongT5Config"
 _TOKENIZER_FOR_DOC = "T5Tokenizer"
 
+remat = nn_partitioning.remat
+
 
 # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
 def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
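The two hunks above import Flax's partitioning utilities and alias nn_partitioning.remat as remat, the rematerialization (gradient checkpointing) transform that the block collection below wraps its layers with. As a standalone illustration of the pattern (not part of the commit; SmallMlp and its shapes are made up), a remat-wrapped Flax module recomputes its forward activations during the backward pass instead of storing them, and static_argnums marks plain-Python arguments that must stay concrete:

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.linen import partitioning as nn_partitioning


class SmallMlp(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x, deterministic):
        x = nn.Dense(self.features)(x)
        x = nn.Dropout(rate=0.1)(x, deterministic=deterministic)
        return x


# remat(...) returns a wrapped module class whose backward pass recomputes the
# forward activations rather than keeping them in memory. static_argnums=(1,)
# marks `deterministic` (a plain Python bool, argument index 1 of __call__ not
# counting self) as static, mirroring static_argnums=(6, 7, 8) in the hunks below.
CheckpointedMlp = nn_partitioning.remat(SmallMlp, static_argnums=(1,))

x = jnp.ones((2, 4))
params = CheckpointedMlp(features=8).init(jax.random.PRNGKey(0), x, True)
out = CheckpointedMlp(features=8).apply(params, x, True)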
@@ -1356,7 +1359,6 @@ def __call__(
         encoder_attention_mask=None,
         encoder_decoder_position_bias=None,
         output_attentions=False,
-        return_dict=True,
         deterministic=True,
         init_cache=False,
     ):
@@ -1377,13 +1379,31 @@ def __call__(
 class FlaxLongT5BlockCollection(nn.Module):
     config: LongT5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
         self.causal = self.config.causal
-        self.blocks = [
-            FlaxLongT5LayerCollection(self.config, has_relative_attention_bias=(i == 0), dtype=self.dtype, name=str(i))
-            for i in range(self.config.num_layers)
-        ]
+        if self.gradient_checkpointing:
+            FlaxLongT5CheckpointLayer = remat(FlaxLongT5LayerCollection, static_argnums=(6, 7, 8))
+            self.blocks = [
+                FlaxLongT5CheckpointLayer(
+                    self.config,
+                    has_relative_attention_bias=(i == 0),
+                    dtype=self.dtype,
+                    name=str(i),
+                )
+                for i in range(self.config.num_layers)
+            ]
+        else:
+            self.blocks = [
+                FlaxLongT5LayerCollection(
+                    self.config,
+                    has_relative_attention_bias=(i == 0),
+                    dtype=self.dtype,
+                    name=str(i),
+                )
+                for i in range(self.config.num_layers)
+            ]
 
     def __call__(
         self,
@@ -1409,14 +1429,14 @@ def __call__(
 
             layer_outputs = layer_module(
                 hidden_states,
-                attention_mask=attention_mask,
-                position_bias=position_bias,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                encoder_decoder_position_bias=encoder_decoder_position_bias,
-                output_attentions=output_attentions,
-                deterministic=deterministic,
-                init_cache=init_cache,
+                attention_mask,
+                position_bias,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                encoder_decoder_position_bias,
+                output_attentions,
+                deterministic,
+                init_cache,
             )
 
             hidden_states = layer_outputs[0]
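Note on the hunk above (not part of the commit): static_argnums in remat refers to positional argument indices of the layer's __call__, which is why the loop body switches from keyword to positional calls; the three flagged arguments (output_attentions, deterministic, init_cache) are plain Python booleans that must stay static so they can drive Python-level control flow. A minimal sketch of the same idea with bare jax.checkpoint (the function below is hypothetical):

import jax
import jax.numpy as jnp


def layer(x, scale, deterministic):
    # Python-level branch: only legal if `deterministic` is a concrete bool,
    # i.e. not traced, which is what static_argnums guarantees.
    return x * scale if deterministic else x * scale * 2.0


layer_ckpt = jax.checkpoint(layer, static_argnums=(2,))
y = layer_ckpt(jnp.ones(3), 2.0, True)  # the static flag is passed positionally
grad = jax.grad(lambda x: layer_ckpt(x, 2.0, True).sum())(jnp.ones(3))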
@@ -1447,11 +1467,14 @@ class FlaxLongT5Stack(nn.Module):
     config: LongT5Config
     embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def setup(self):
         self.causal = self.config.causal
 
-        self.block = FlaxLongT5BlockCollection(self.config, dtype=self.dtype)
+        self.block = FlaxLongT5BlockCollection(
+            self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
         self.final_layer_norm = FlaxLongT5LayerNorm(
             self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
         )
@@ -1989,6 +2012,7 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs
 class FlaxLongT5Module(nn.Module):
     config: LongT5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def _get_encoder_module(self):
         return self.encoder
@@ -2005,12 +2029,22 @@ def setup(self):
 
         encoder_config = copy.deepcopy(self.config)
         encoder_config.causal = False
-        self.encoder = FlaxLongT5Stack(encoder_config, embed_tokens=self.shared, dtype=self.dtype)
+        self.encoder = FlaxLongT5Stack(
+            encoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
 
         decoder_config = copy.deepcopy(self.config)
         decoder_config.causal = True
         decoder_config.num_layers = self.config.num_decoder_layers
-        self.decoder = FlaxLongT5Stack(decoder_config, embed_tokens=self.shared, dtype=self.dtype)
+        self.decoder = FlaxLongT5Stack(
+            decoder_config,
+            embed_tokens=self.shared,
+            dtype=self.dtype,
+            gradient_checkpointing=self.gradient_checkpointing,
+        )
 
     def __call__(
         self,
@@ -2104,6 +2138,7 @@ class FlaxLongT5Model(FlaxLongT5PreTrainedModel):
 class FlaxLongT5ForConditionalGenerationModule(nn.Module):
     config: LongT5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
 
     def _get_encoder_module(self):
         return self.encoder
@@ -2124,13 +2159,17 @@ def setup(self):
         encoder_config.causal = False
         encoder_config.use_cache = False
         encoder_config.is_encoder_decoder = False
-        self.encoder = FlaxLongT5Stack(encoder_config, self.shared, dtype=self.dtype)
+        self.encoder = FlaxLongT5Stack(
+            encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
 
         decoder_config = copy.deepcopy(self.config)
         decoder_config.causal = True
         decoder_config.is_encoder_decoder = False
         decoder_config.num_layers = self.config.num_decoder_layers
-        self.decoder = FlaxLongT5Stack(decoder_config, self.shared, dtype=self.dtype)
+        self.decoder = FlaxLongT5Stack(
+            decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
+        )
 
         self.lm_head = nn.Dense(
             self.config.vocab_size,
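For context (not part of the commit), this is how the new flag would typically be switched on from user code. The sketch assumes the enable_gradient_checkpointing() helper that transformers' other Flax models expose on FlaxPreTrainedModel; the commented line shows the module-level flag added here as an alternative.

from transformers import FlaxLongT5ForConditionalGeneration

model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")
# Assumed helper: rebuilds the inner module with gradient_checkpointing=True so the
# remat branch in FlaxLongT5BlockCollection.setup() is taken.
model.enable_gradient_checkpointing()

# Equivalent at the module level, using the flag introduced in this commit:
# FlaxLongT5ForConditionalGenerationModule(model.config, gradient_checkpointing=True)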