Add KuzushijiMNIST example #175
Conversation
I accidentally closed this PR. Sorry for the notification noise! 😓
Thank you @goodhamgupta for the PR! I want this example to get merged; my concern is that it is too similar to the original MNIST example and does not demonstrate or highlight a unique feature of Axon. I think we can fix that though :) We can adapt one of the examples in https://keras.io/examples/ using this dataset and Axon. At a quick glance, https://keras.io/examples/vision/metric_learning/, https://keras.io/examples/vision/near_dup_search/, and anything in https://keras.io/examples/generative/ seem very doable for now. Or if you find another example you'd like to do but aren't sure how to make it work in Axon, I'd be glad to help walk you through putting the example together in this PR! Just let me know what looks interesting and we can get something up and running. Thanks again for the contribution!
Thank you for the kind review @seanmor5! I fully agree that this PR is a copy of the MNIST example for now and doesn't provide any new insight into the features of Axon. I want to implement the metric learning example for similarity search (https://keras.io/examples/vision/metric_learning/). Would you prefer I close this PR for now and open a new PR once it's ready, or should I continue on this PR?
You can continue working in this PR! Let me know if you run into any problems and I'll be glad to help!
Hi @seanmor5, sorry for pinging you directly, but if you could provide any hints on how to proceed from my previous comment, it would be super helpful! 😅
A review thread is anchored on these lines of the diff:

```elixir
|> Axon.global_avg_pool()
|> Axon.dense(8)
# |> l2_normalize()
```
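Since the PR never shows the body of `l2_normalize`, here is a minimal sketch of what it could look like as an Nx `defn` (an assumption, not code from the PR):

```elixir
import Nx.Defn

# Hypothetical implementation: scale each embedding row to unit L2 norm,
# so dot products between embeddings behave like cosine similarities.
defn l2_normalize(x) do
  norm =
    x
    |> Nx.multiply(x)
    |> Nx.sum(axes: [-1], keep_axes: true)
    |> Nx.sqrt()

  x / norm
end
```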
To apply this:
```elixir
|> Axon.nx(&l2_normalize/1)
```

A second review thread is anchored on this `batch_step` stub:

```elixir
defn batch_step(model, optim, real_images, state) do
  iter = state[:iteration]
  params = state[:model_state]
  IO.puts(iter)
  # Add code to compute cosine similarity for metric learning
end
```
The example is something similar to this:

```elixir
defn objective_fn(model, params, {anchor, positive}) do
  anchor_embeddings = Axon.predict(model, params, anchor)
  positive_embeddings = Axon.predict(model, params, positive)

  # This is a pair-wise dot product in Nx: contracting axis 1 of both
  # {batch, embedding} tensors yields a {batch, batch} similarity matrix
  similarities = Nx.dot(anchor_embeddings, [1], positive_embeddings, [1])

  # Sharpen the distribution over similarities
  temperature = 0.2
  similarities = similarities / temperature

  # Anchor i should be most similar to positive i, so the "label" of
  # row i is i itself
  sparse_labels = Nx.iota({@num_classes})

  Axon.Losses.categorical_cross_entropy(sparse_labels, similarities,
    reduction: :mean,
    sparse: true,
    from_logits: true
  )
end
```
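As an aside on that pair-wise dot product, a tiny check with made-up tensors (not from the original comment) shows the shape and meaning of the result:

```elixir
# Two {2, 3} embedding batches produce a {2, 2} similarity matrix,
# where entry {i, j} is the dot product of anchor i with positive j.
anchor = Nx.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
positive = Nx.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])

Nx.dot(anchor, [1], positive, [1])
#=> [[1.0, 0.0], [0.0, 0.0]]
```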
```elixir
defn batch_step(model, optim, {anchor, positive}, state) do
  # Compute gradient of objective defined above
  {loss, gradients} =
    value_and_grad(state.model_state, &objective_fn(model, &1, {anchor, positive}))

  # Step with optim (Axon.Updates update functions take gradients first,
  # then optimizer state, then params)
  {updates, new_optimizer_state} =
    optim.(gradients, state.optimizer_state, state.model_state)

  new_params = Axon.Updates.apply_updates(state.model_state, updates)

  %{
    state
    | model_state: new_params,
      optimizer_state: new_optimizer_state,
      iteration: state.iteration + 1,
      loss: loss
  }
end
```

Then I would have something like:
```elixir
# Close over the model and optimizer so the loop only sees (batch, state)
defp make_batch_step(model, optimizer) do
  &batch_step(model, optimizer, &1, &2)
end
```
Then to train, you just do something like:

```elixir
model = build_model()
{optim_init_fn, optim_update_fn} = Axon.Optimizers.adam(1.0e-3)

init_loop = fn -> init(model, optim_init_fn) end
batch_step = make_batch_step(model, optim_update_fn)

batch_step
|> Axon.Loop.loop(init_loop)
|> Axon.Loop.run(...)
```

You'll just need to work on getting the data into {anchor, positive} pairs then.
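One way to build those pairs is to group image indices by class and sample one anchor and one positive per class. A rough sketch, under my own assumptions (`images` is an `{n, 28, 28, 1}` tensor and `labels` an `{n}` tensor; neither name comes from the thread):

```elixir
defp anchor_positive_stream(images, labels) do
  # Map each class label to the list of image indices with that label
  class_to_indices =
    labels
    |> Nx.to_flat_list()
    |> Enum.with_index()
    |> Enum.group_by(fn {label, _i} -> label end, fn {_label, i} -> i end)

  classes = Map.keys(class_to_indices)

  Stream.repeatedly(fn ->
    # For each class, pick two distinct images: one anchor, one positive
    {anchor_idx, positive_idx} =
      classes
      |> Enum.map(fn class ->
        [a, p] = Enum.take_random(class_to_indices[class], 2)
        {a, p}
      end)
      |> Enum.unzip()

    anchor = Nx.take(images, Nx.tensor(anchor_idx))
    positive = Nx.take(images, Nx.tensor(positive_idx))
    {anchor, positive}
  end)
end
```

Each element of the stream is then one `{anchor, positive}` batch you can feed to the loop above.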
Hi @goodhamgupta, I am closing this for now as stale. If you want to re-open and continue work, feel free, and I will do my best to guide you in the right direction :)
Hi everyone,
Thanks for this excellent library! This PR aims to add an example of training on the Kuzushiji-MNIST dataset using Axon (#47). The file is almost an exact replica of mnist.exs, with the only differences being the Scidata version used and how we transform the images and labels.

Training the model locally on my laptop, I get the following metrics:
Thanks!
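For anyone following along, the data loading and transformation described above might look roughly like this (a sketch, assuming Scidata exposes a Scidata.KuzushijiMNIST module with the same download/0 API as Scidata.MNIST):

```elixir
# Download returns {{image_binary, type, shape}, {label_binary, type, shape}}
{{image_data, image_type, image_shape}, {label_data, label_type, label_shape}} =
  Scidata.KuzushijiMNIST.download()

# Reshape the raw bytes into an image tensor and scale pixels to [0, 1]
images =
  image_data
  |> Nx.from_binary(image_type)
  |> Nx.reshape(image_shape)
  |> Nx.divide(255)

labels =
  label_data
  |> Nx.from_binary(label_type)
  |> Nx.reshape(label_shape)
```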