
Conversation

@sitammeur (Contributor):

This PR updates the "OCR model for reading Captchas" example to Keras 3.0 [TF-only backend]. Many TF ops are replaced with their corresponding Keras ops.

For reference, here is a Colab notebook demonstrating the changes:
https://colab.research.google.com/drive/1vCDb45wLmSI3iBI2_BfDDYDSztgxZ4Qp?usp=sharing

cc: @fchollet @divyashreepathihalli
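
For illustration, this is the general substitution pattern used throughout the diff (a minimal sketch; `x` and the result names are hypothetical, not variables from the example):

```python
import tensorflow as tf
from keras import ops

x = tf.random.uniform((2, 3, 4))

# tf.transpose takes `perm`; the Keras op takes `axes`.
y = ops.transpose(x, axes=[1, 0, 2])  # was: tf.transpose(x, perm=[1, 0, 2])

# tf.cast takes a dtype object; the Keras op accepts a dtype string.
z = ops.cast(x, dtype="int32")  # was: tf.cast(x, tf.int32)

# tf.concat becomes ops.concatenate; tf.range becomes ops.arange.
c = ops.concatenate([x, x], axis=0)  # was: tf.concat([x, x], axis=0)
i = ops.arange(0, 4)  # was: tf.range(0, 4)
```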

The following is the Git diff for the changed file:

```diff
diff --git a/examples/vision/captcha_ocr.py b/examples/vision/captcha_ocr.py
index 3a2b8e96..06115b0f 100644
--- a/examples/vision/captcha_ocr.py
+++ b/examples/vision/captcha_ocr.py
@@ -35,6 +35,7 @@ from collections import Counter
 
 import tensorflow as tf
 import keras
+from keras import ops
 from keras import layers
 
 """
@@ -109,9 +110,9 @@ def split_data(images, labels, train_size=0.9, shuffle=True):
     # 1. Get the total size of the dataset
     size = len(images)
     # 2. Make an indices array and shuffle it, if required
-    indices = np.arange(size)
+    indices = ops.arange(size)
     if shuffle:
-        np.random.shuffle(indices)
+        keras.random.shuffle(indices)
     # 3. Get the size of training samples
     train_samples = int(size * train_size)
     # 4. Split data into training and validation sets
@@ -132,10 +133,10 @@ def encode_single_sample(img_path, label):
     # 3. Convert to float32 in [0, 1] range
     img = tf.image.convert_image_dtype(img, tf.float32)
     # 4. Resize to the desired size
-    img = tf.image.resize(img, [img_height, img_width])
+    img = ops.image.resize(img, [img_height, img_width])
     # 5. Transpose the image because we want the time
     # dimension to correspond to the width of the image.
-    img = tf.transpose(img, perm=[1, 0, 2])
+    img = ops.transpose(img, axes=[1, 0, 2])
     # 6. Map the characters in label to numbers
     label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
     # 7. Return a dict as our model is expecting two inputs
@@ -184,13 +185,13 @@ plt.show()
 
 
 def ctc_batch_cost(y_true, y_pred, input_length, label_length):
-    label_length = tf.cast(tf.squeeze(label_length, axis=-1), tf.int32)
-    input_length = tf.cast(tf.squeeze(input_length, axis=-1), tf.int32)
-    sparse_labels = tf.cast(ctc_label_dense_to_sparse(y_true, label_length), tf.int32)
+    label_length = ops.cast(ops.squeeze(label_length, axis=-1), dtype="int32")
+    input_length = ops.cast(ops.squeeze(input_length, axis=-1), dtype="int32")
+    sparse_labels = ops.cast(ctc_label_dense_to_sparse(y_true, label_length), dtype="int32")
 
-    y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + keras.backend.epsilon())
+    y_pred = ops.log(ops.transpose(y_pred, axes=[1, 0, 2]) + keras.backend.epsilon())
 
-    return tf.expand_dims(
+    return ops.expand_dims(
         tf.compat.v1.nn.ctc_loss(
             inputs=y_pred, labels=sparse_labels, sequence_length=input_length
         ),
@@ -199,41 +200,41 @@ def ctc_batch_cost(y_true, y_pred, input_length, label_length):
 
 
 def ctc_label_dense_to_sparse(labels, label_lengths):
-    label_shape = tf.shape(labels)
-    num_batches_tns = tf.stack([label_shape[0]])
-    max_num_labels_tns = tf.stack([label_shape[1]])
+    label_shape = ops.shape(labels)
+    num_batches_tns = ops.stack([label_shape[0]])
+    max_num_labels_tns = ops.stack([label_shape[1]])
 
     def range_less_than(old_input, current_input):
-        return tf.expand_dims(tf.range(tf.shape(old_input)[1]), 0) < tf.fill(
+        return ops.expand_dims(ops.arange(ops.shape(old_input)[1]), 0) < tf.fill(
             max_num_labels_tns, current_input
         )
 
-    init = tf.cast(tf.fill([1, label_shape[1]], 0), tf.bool)
+    init = ops.cast(tf.fill([1, label_shape[1]], 0), dtype="bool")
     dense_mask = tf.compat.v1.scan(
         range_less_than, label_lengths, initializer=init, parallel_iterations=1
     )
     dense_mask = dense_mask[:, 0, :]
 
-    label_array = tf.reshape(
-        tf.tile(tf.range(0, label_shape[1]), num_batches_tns), label_shape
+    label_array = ops.reshape(
+        ops.tile(ops.arange(0, label_shape[1]), num_batches_tns), label_shape
     )
     label_ind = tf.compat.v1.boolean_mask(label_array, dense_mask)
 
-    batch_array = tf.transpose(
-        tf.reshape(
-            tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns),
+    batch_array = ops.transpose(
+        ops.reshape(
+            ops.tile(ops.arange(0, label_shape[0]), max_num_labels_tns),
             tf.reverse(label_shape, [0]),
         )
     )
     batch_ind = tf.compat.v1.boolean_mask(batch_array, dense_mask)
-    indices = tf.transpose(
-        tf.reshape(tf.concat([batch_ind, label_ind], axis=0), [2, -1])
+    indices = ops.transpose(
+        ops.reshape(ops.concatenate([batch_ind, label_ind], axis=0), [2, -1])
     )
 
     vals_sparse = tf.compat.v1.gather_nd(labels, indices)
 
     return tf.SparseTensor(
-        tf.cast(indices, tf.int64), vals_sparse, tf.cast(label_shape, tf.int64)
+        ops.cast(indices, dtype="int64"), vals_sparse, ops.cast(label_shape, dtype="int64")
     )
 
 
@@ -245,12 +246,12 @@ class CTCLayer(layers.Layer):
     def call(self, y_true, y_pred):
         # Compute the training-time loss value and add it
         # to the layer using `self.add_loss()`.
-        batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
-        input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
-        label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")
+        batch_len = ops.cast(tf.shape(y_true)[0], dtype="int64")
+        input_length = ops.cast(tf.shape(y_pred)[1], dtype="int64")
+        label_length = ops.cast(tf.shape(y_true)[1], dtype="int64")
 
-        input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
-        label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
+        input_length = input_length * ops.ones(shape=(batch_len, 1), dtype="int64")
+        label_length = label_length * ops.ones(shape=(batch_len, 1), dtype="int64")
 
         loss = self.loss_fn(y_true, y_pred, input_length, label_length)
         self.add_loss(loss)
@@ -355,10 +356,10 @@ and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io
 
 
 def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
-    input_shape = tf.shape(y_pred)
+    input_shape = ops.shape(y_pred)
     num_samples, num_steps = input_shape[0], input_shape[1]
-    y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + keras.backend.epsilon())
-    input_length = tf.cast(input_length, tf.int32)
+    y_pred = ops.log(ops.transpose(y_pred, axes=[1, 0, 2]) + keras.backend.epsilon())
+    input_length = ops.cast(input_length, dtype="int32")
 
     if greedy:
         (decoded, log_prob) = tf.nn.ctc_greedy_decoder(
```
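
One behavioral detail worth flagging in the `split_data` hunk (an aside, not part of the diff): `np.random.shuffle` shuffles its argument in place, whereas `keras.random.shuffle` returns a shuffled copy, so the result needs to be assigned back. A minimal sketch of the intended usage, with illustrative variable names:

```python
import keras
from keras import ops

size = 10
indices = ops.arange(size)

# keras.random.shuffle is functional: it returns a new tensor
# instead of mutating `indices` in place as np.random.shuffle does.
indices = keras.random.shuffle(indices)

train_samples = int(size * 0.9)
train_idx, valid_idx = indices[:train_samples], indices[train_samples:]
```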

@fchollet (Contributor) left a comment:


Looks good, thank you. Please add the generated files.

@sitammeur (Contributor, Author) replied:

> Looks good, thank you. Please add the generated files.

The .md and .ipynb files have been added. I also tried the same approach on a handwritten OCR example; model training with Keras 3 was successful there as well, so I will open that PR shortly.

@fchollet (Contributor) left a comment:


LGTM -- Thanks!

@fchollet merged commit 14136a8 into keras-team:master on Mar 13, 2024.
sitammeur added a commit to sitammeur/keras-io that referenced this pull request May 30, 2024
…ras-team#1788)

* replaced the tf ops with keras ops

* final formatting done

* .md and .ipynb files are added

* minor changes done