diff --git a/tensorflow_similarity/models/similarity_model.py b/tensorflow_similarity/models/similarity_model.py
index 69677a37..3b092853 100644
--- a/tensorflow_similarity/models/similarity_model.py
+++ b/tensorflow_similarity/models/similarity_model.py
@@ -87,8 +87,46 @@ class SimilarityModel(tf.keras.Model):
     """
 
     def __init__(self, *args, **kwargs):
+        self.batch_compute_gradient_portion = float(kwargs.pop('batch_compute_gradient_portion', 1))
+        self.batch_random_permutation = bool(kwargs.pop('batch_random_permutation', False))
+
+        assert 0. < self.batch_compute_gradient_portion <= 1.
+        assert self.batch_random_permutation in [True, False]
+
         super().__init__(*args, **kwargs)
 
+    def train_step(self, data):
+        x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
+
+        if self.batch_random_permutation:
+            indices = tf.range(start=0, limit=tf.shape(x)[0], dtype=tf.int32)
+            shuffled_indices = tf.random.shuffle(indices)
+
+            x = tf.gather(x, shuffled_indices)
+            y = tf.gather(y, shuffled_indices)
+            if sample_weight is not None:
+                sample_weight = tf.gather(sample_weight, shuffled_indices)
+
+        l = tf.cast(tf.shape(x)[0], tf.float32)
+        k = tf.cast(self.batch_compute_gradient_portion * l, tf.int32)
+
+        # Run forward pass.
+        y_pred_without_gradient = self(x[k:], training=True)
+
+        with tf.GradientTape() as tape:
+            y_pred_with_gradient = self(x[:k], training=True)
+
+            y_pred = tf.concat([y_pred_with_gradient, y_pred_without_gradient], axis=0)
+
+            loss = self.compute_loss(x, y, y_pred, sample_weight)
+
+        self._validate_target_and_loss(y, loss)
+
+        # Run backwards pass.
+        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
+
+        return self.compute_metrics(x, y, y_pred, sample_weight)
+
     def compile(
         self,
         optimizer: Optimizer | str | Mapping | Sequence = "rmsprop",
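
For reference, below is a minimal usage sketch of the two new constructor options; it is not part of the patch. It assumes the usual tensorflow_similarity entry points (MetricEmbedding, MultiSimilarityLoss) and a toy backbone, so swap in your own architecture and loss as needed. With batch_compute_gradient_portion=0.25 and a batch of 256, only the first 64 examples are run under the GradientTape; the remaining 192 are embedded outside the tape and enter the loss as constants, and batch_random_permutation=True reshuffles the batch each step so the taped slice is not always the same examples.

```python
# Usage sketch only; MetricEmbedding / MultiSimilarityLoss and the toy
# backbone are illustrative assumptions, not part of this change.
import tensorflow as tf
from tensorflow_similarity.models import SimilarityModel
from tensorflow_similarity.layers import MetricEmbedding
from tensorflow_similarity.losses import MultiSimilarityLoss

inputs = tf.keras.Input(shape=(28, 28, 1))
x = tf.keras.layers.Conv2D(32, 3, activation="relu")(inputs)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = MetricEmbedding(64)(x)

# The two kwargs are popped in __init__ before being forwarded to
# tf.keras.Model, so they can be passed alongside inputs/outputs.
model = SimilarityModel(
    inputs,
    outputs,
    batch_compute_gradient_portion=0.25,  # tape gradients for 25% of each batch
    batch_random_permutation=True,        # reshuffle the batch every step
)
model.compile(optimizer="adam", loss=MultiSimilarityLoss())
# model.fit(train_ds, epochs=10)
```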