examples/vision/captcha_ocr.py (96 changes: 89 additions & 7 deletions)
@@ -22,6 +4,10 @@
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import os
import numpy as np
import matplotlib.pyplot as plt
@@ -30,9 +34,8 @@
from collections import Counter

import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
-
+import keras
+from keras import layers

"""
## Load the data: [Captcha Images](https://www.kaggle.com/fournierp/captcha-version-2-images)
@@ -180,10 +183,64 @@ def encode_single_sample(img_path, label):
"""


def ctc_batch_cost(y_true, y_pred, input_length, label_length):
    label_length = tf.cast(tf.squeeze(label_length, axis=-1), tf.int32)
    input_length = tf.cast(tf.squeeze(input_length, axis=-1), tf.int32)
    sparse_labels = tf.cast(ctc_label_dense_to_sparse(y_true, label_length), tf.int32)

    y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + keras.backend.epsilon())

    return tf.expand_dims(
        tf.compat.v1.nn.ctc_loss(
            inputs=y_pred, labels=sparse_labels, sequence_length=input_length
        ),
        1,
    )


def ctc_label_dense_to_sparse(labels, label_lengths):
    label_shape = tf.shape(labels)
    num_batches_tns = tf.stack([label_shape[0]])
    max_num_labels_tns = tf.stack([label_shape[1]])

    def range_less_than(old_input, current_input):
        return tf.expand_dims(tf.range(tf.shape(old_input)[1]), 0) < tf.fill(
            max_num_labels_tns, current_input
        )

    init = tf.cast(tf.fill([1, label_shape[1]], 0), tf.bool)
    dense_mask = tf.compat.v1.scan(
        range_less_than, label_lengths, initializer=init, parallel_iterations=1
    )
    dense_mask = dense_mask[:, 0, :]

    label_array = tf.reshape(
        tf.tile(tf.range(0, label_shape[1]), num_batches_tns), label_shape
    )
    label_ind = tf.compat.v1.boolean_mask(label_array, dense_mask)

    batch_array = tf.transpose(
        tf.reshape(
            tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns),
            tf.reverse(label_shape, [0]),
        )
    )
    batch_ind = tf.compat.v1.boolean_mask(batch_array, dense_mask)
    indices = tf.transpose(
        tf.reshape(tf.concat([batch_ind, label_ind], axis=0), [2, -1])
    )

    vals_sparse = tf.compat.v1.gather_nd(labels, indices)

    return tf.SparseTensor(
        tf.cast(indices, tf.int64), vals_sparse, tf.cast(label_shape, tf.int64)
    )


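# A hypothetical smoke test for the two helpers above; not part of the PR.
# It assumes a batch of 2 samples, 50 time steps, a 21-symbol vocabulary
# (20 characters plus the CTC blank), and 5-character labels.
_demo_pred = tf.nn.softmax(tf.random.normal((2, 50, 21)))
_demo_labels = tf.random.uniform((2, 5), maxval=20, dtype=tf.int32)
_demo_loss = ctc_batch_cost(
    y_true=_demo_labels,
    y_pred=_demo_pred,
    input_length=tf.fill((2, 1), 50),
    label_length=tf.fill((2, 1), 5),
)
# _demo_loss has shape (2, 1): one CTC loss value per sample.
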
class CTCLayer(layers.Layer):
    def __init__(self, name=None):
        super().__init__(name=name)
-        self.loss_fn = keras.backend.ctc_batch_cost
+        self.loss_fn = ctc_batch_cost

    def call(self, y_true, y_pred):
        # Compute the training-time loss value and add it
@@ -272,7 +329,8 @@ def build_model():
"""


-epochs = 100
+# TODO restore epoch count.
+epochs = 2
early_stopping_patience = 10
# Add early stopping
early_stopping = keras.callbacks.EarlyStopping(
@@ -296,9 +354,33 @@
"""


def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
    input_shape = tf.shape(y_pred)
    num_samples, num_steps = input_shape[0], input_shape[1]
    y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + keras.backend.epsilon())
    input_length = tf.cast(input_length, tf.int32)

    if greedy:
        (decoded, log_prob) = tf.nn.ctc_greedy_decoder(
            inputs=y_pred, sequence_length=input_length
        )
    else:
        (decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder(
            inputs=y_pred,
            sequence_length=input_length,
            beam_width=beam_width,
            top_paths=top_paths,
        )
    decoded_dense = []
    for st in decoded:
        st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps))
        decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1))
    return (decoded_dense, log_prob)

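# A hypothetical smoke test for ctc_decode; not part of the PR. It assumes the
# same (batch, time, classes) softmax layout used above. Greedy decoding returns
# a list with one dense tensor of shape (batch, time), padded with -1.
_demo_pred = tf.nn.softmax(tf.random.normal((2, 50, 21)))
_demo_decoded, _demo_log_prob = ctc_decode(
    _demo_pred, input_length=np.ones(2) * 50, greedy=True
)
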

# Get the prediction model by extracting layers till the output layer
prediction_model = keras.models.Model(
-    model.get_layer(name="image").input, model.get_layer(name="dense2").output
+    model.input[0], model.get_layer(name="dense2").output
)
prediction_model.summary()

@@ -307,7 +389,7 @@ def build_model():
def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search
-    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
+    results = ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
        :, :max_length
    ]
    # Iterate over the results and get back the text
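# Hypothetical end-to-end usage of the pieces above; not from the PR. It assumes
# the validation_dataset built earlier in the example, whose batches are dicts
# with "image" and "label" keys.
for batch in validation_dataset.take(1):
    preds = prediction_model.predict(batch["image"])
    print(decode_batch_predictions(preds))
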
Binary file modified examples/vision/img/captcha_ocr/captcha_ocr_13_0.png