
Conversation

@goodhamgupta commented Dec 7, 2021

Hi everyone,

Thanks for this excellent library! This PR adds an example for training a model on the KuzushijiMNIST dataset using Axon (#47). The file is almost an exact replica of mnist.exs; the only differences are the Scidata dataset used and how the images and labels are transformed.
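
The transform looks roughly like this (a sketch; it assumes Scidata.KuzushijiMNIST.download/0 returns {images, labels} as the same {binary, type, shape} tuples that Scidata.MNIST does):

{{images_binary, images_type, images_shape}, {labels_binary, labels_type, _labels_shape}} =
  Scidata.KuzushijiMNIST.download()

# Flatten each 28x28 image into a 784-element row and scale pixels to [0, 1]
images =
  images_binary
  |> Nx.from_binary(images_type)
  |> Nx.reshape({elem(images_shape, 0), 784})
  |> Nx.divide(255)

# One-hot encode the 10 class labels
labels =
  labels_binary
  |> Nx.from_binary(labels_type)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))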

Training the model locally on my laptop, I get the following metrics:

--------------------------------------------------
                      Model
==================================================
 Layer                    Shape        Parameters
==================================================
 input_8 ( input )        {nil, 784}   0
 dense_11 ( dense )       {nil, 128}   100480
 relu_12 ( relu )         {nil, 128}   0
 dropout_13 ( dropout )   {nil, 128}   0
 dense_16 ( dense )       {nil, 10}    1290
 softmax_17 ( softmax )   {nil, 10}    0
--------------------------------------------------


Training Model

Epoch: 4, Batch: 1750, Loss: 0.52946 Accuracy: 0.97277

 Testing Model

Epoch: 0, Batch: 100,  Accuracy: 0.98372

Thanks!

@goodhamgupta reopened this Dec 7, 2021
@goodhamgupta (Author) commented:

I accidentally closed this PR. Sorry for the notification noise! 😓

@seanmor5 (Contributor) commented Dec 8, 2021

Thank you @goodhamgupta for the PR! I want this example to get merged; my concern is that it is too similar to the original MNIST example and does not demonstrate or highlight a unique feature of Axon. I think we can fix that though :)

I think we can adapt one of the examples in https://keras.io/examples/ using this dataset and Axon. At a quick glance, https://keras.io/examples/vision/metric_learning/, https://keras.io/examples/vision/near_dup_search/, and anything in https://keras.io/examples/generative/ seem very doable for now. Or, if you find another example you'd like to do but aren't sure how to make it work in Axon, I'd be glad to help walk you through putting the example together in this PR! Just let me know what looks interesting and we can get something up and running.

Thanks again for the contribution!

@goodhamgupta (Author) commented:

Thank you for the kind review @seanmor5! I fully agree that this PR is a copy of the MNIST example for now and doesn't provide any new insight into the features of Axon.

I want to implement the metric learning example for similarity search (https://keras.io/examples/vision/metric_learning/). Would you prefer I close this PR for now and open a new PR once it's ready, or should I continue on this PR?

@seanmor5 (Contributor) commented Dec 8, 2021

You can continue working in this PR! Let me know if you run into any problems and I'll be glad to help!

@goodhamgupta (Author) commented:

Hi @seanmor5,

Sorry for pinging you directly, but if you could provide any hints on how to proceed from my previous comment, it would be super helpful! 😅

|> Axon.global_avg_pool()
|> Axon.dense(8)

# |> l2_normalize()
A review comment from a Contributor on the diff above:

To apply this:

|> Axon.nx(&l2_normalize/1)
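
where l2_normalize/1 could be a small defn along these lines (just a sketch, assuming rank-2 embeddings normalized along the last axis):

# Requires `import Nx.Defn`
defn l2_normalize(x) do
  # Scale each embedding row to unit L2 norm
  norm = Nx.sqrt(Nx.sum(x * x, axes: [-1], keep_axes: true))
  x / norm
end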

Comment on lines +116 to +121
defn batch_step(model, optim, real_images, state) do
  iter = state[:iteration]
  params = state[:model_state]
  IO.puts(iter)
  # Add code to compute cosine similarity for metric learning
end
A review comment from a Contributor:

The example is something similar to this:

defn objective_fn(model, params, {anchor, positive}) do
  anchor_embeddings = Axon.predict(model, params, anchor)
  positive_embeddings = Axon.predict(model, params, positive)
  # This is a pair-wise dot product in Nx
  similarities = Nx.dot(anchor_embeddings, [1], positive_embeddings, [1])
  temperature = 0.2
  similarities = similarities / temperature
  sparse_labels = Nx.iota({@num_classes})
  Axon.Losses.categorical_cross_entropy(sparse_labels, similarities, reduction: :mean, sparse: true, from_logits: true)
end

defn batch_step(model, optim, {anchor, positive}, state) do
  # Compute gradient of objective defined above
  {loss, gradients} = value_and_grad(state.model_state, &objective_fn(model, &1, {anchor, positive}))
  # Step with optim
  {updates, new_optimizer_state} = optim.(state.optimizer_state, state.model_state, gradients)
  new_params = Axon.Updates.apply_updates(state.model_state, updates)

  %{state |
      model_state: new_params,
      optimizer_state: new_optimizer_state,
      iteration: state.iteration + 1,
      loss: loss}
end

Then I would have something like:

defp make_batch_step(model, optimizer) do
  &batch_step(model, optimizer, &1, &2)
end

Then to train, you just do something like:

model = build_model()
{optim_init_fn, optim_update_fn} = Axon.Optimizers.adam(1.0e-3)
init_loop = fn -> init(model, optim_init_fn) end
batch_step = make_batch_step(model, optim_update_fn)

batch_step
|> Axon.Loop.loop(init_loop)
|> Axon.Loop.run(...)

You'll just need to work on getting the data in the {anchor, positive} pairs then.
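
One way to build those pairs, following the Keras example of sampling one anchor/positive pair per class per batch (a rough, untested sketch with a hypothetical make_pair_batch/3 helper):

# images: {n, 28, 28, 1} tensor, labels: {n} tensor of class ids
defp make_pair_batch(images, labels, num_classes) do
  # Group image indices by class id
  grouped =
    labels
    |> Nx.to_flat_list()
    |> Enum.with_index()
    |> Enum.group_by(fn {label, _idx} -> label end, fn {_label, idx} -> idx end)

  # Sample two distinct images from each class: one anchor, one positive
  {anchor_idx, positive_idx} =
    for class <- 0..(num_classes - 1), reduce: {[], []} do
      {anchors, positives} ->
        [a, p] = Enum.take_random(grouped[class], 2)
        {[a | anchors], [p | positives]}
    end

  anchor = Nx.take(images, Nx.tensor(Enum.reverse(anchor_idx)))
  positive = Nx.take(images, Nx.tensor(Enum.reverse(positive_idx)))
  {anchor, positive}
end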

@seanmor5 (Contributor) commented Sep 6, 2022

Hi @goodhamgupta, I am closing this for now as stale. If you want to re-open and continue work, feel free, and I will do my best to guide you in the right direction :)
