File tree Expand file tree Collapse file tree 2 files changed +21
-6
lines changed
keras_nlp/src/models/stable_diffusion_v3 Expand file tree Collapse file tree 2 files changed +21
-6
lines changed Original file line number Diff line number Diff line change 1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import keras
1415from keras import layers
1516from keras import ops
1617
1718from keras_nlp .src .layers .modeling .token_and_position_embedding import (
1819 TokenAndPositionEmbedding ,
1920)
20- from keras_nlp .src .models .backbone import Backbone
2121from keras_nlp .src .models .stable_diffusion_v3 .clip_encoder_block import (
2222 CLIPEncoderBlock ,
2323)
2424
2525
26- class CLIPTextEncoder (Backbone ):
26+ class CLIPTextEncoder (keras . Model ):
2727 def __init__ (
2828 self ,
2929 embedding_dim ,
@@ -108,7 +108,6 @@ def __init__(
108108 super ().__init__ (
109109 inputs = {"encoder_token_ids" : encoder_token_ids },
110110 outputs = outputs ,
111- dtype = dtype ,
112111 ** kwargs ,
113112 )
114113
@@ -123,6 +122,15 @@ def __init__(
123122 self .vocabulary_size = vocabulary_size
124123 self .sequence_length = sequence_length
125124
125+ if dtype is not None :
126+ try :
127+ self .dtype_policy = keras .dtype_policies .get (dtype )
128+ # Before Keras 3.2, there is no `keras.dtype_policies.get`.
129+ except AttributeError :
130+ if isinstance (dtype , keras .DTypePolicy ):
131+ dtype = dtype .name
132+ self .dtype_policy = keras .DTypePolicy (dtype )
133+
126134 def get_config (self ):
127135 config = super ().get_config ()
128136 config .update (
Original file line number Diff line number Diff line change 1616from keras_nlp .src .layers .modeling .reversible_embedding import (
1717 ReversibleEmbedding ,
1818)
19- from keras_nlp .src .models .backbone import Backbone
2019from keras_nlp .src .models .t5 .t5_layer_norm import T5LayerNorm
2120from keras_nlp .src .models .t5 .t5_transformer_layer import T5TransformerLayer
2221
2322
24- class T5XXLTextEncoder (Backbone ):
23+ class T5XXLTextEncoder (keras . Model ):
2524 def __init__ (
2625 self ,
2726 vocabulary_size ,
@@ -111,7 +110,6 @@ def __init__(
111110 "encoder_padding_mask" : encoder_padding_mask_input ,
112111 },
113112 outputs = encoder_output ,
114- dtype = dtype ,
115113 ** kwargs ,
116114 )
117115
@@ -128,6 +126,15 @@ def __init__(
128126 self .layer_norm_epsilon = layer_norm_epsilon
129127 self .tie_embedding_weights = tie_embedding_weights
130128
129+ if dtype is not None :
130+ try :
131+ self .dtype_policy = keras .dtype_policies .get (dtype )
132+ # Before Keras 3.2, there is no `keras.dtype_policies.get`.
133+ except AttributeError :
134+ if isinstance (dtype , keras .DTypePolicy ):
135+ dtype = dtype .name
136+ self .dtype_policy = keras .DTypePolicy (dtype )
137+
131138 def get_config (self ):
132139 config = super ().get_config ()
133140 config .update (
You can’t perform that action at this time.
0 commit comments