@@ -12,6 +12,7 @@ Mix.install([
1212
1313Nx .global_default_backend (EXLA .Backend )
1414Nx .Defn .global_default_options (compiler: EXLA )
15+
1516```
1617
1718## Dataset
@@ -54,7 +55,6 @@ In metric learning, we don’t hand the model lone examples, instead we show it
5455class_idx_to_train_idxs =
5556 bin
5657 |> Nx .from_binary (type)
57- |> Nx .reshape (shape)
5858 |> Nx .to_flat_list ()
5959 |> Enum .with_index ()
6060 |> Enum .group_by (& elem (&1 , 0 ), fn {_ , i} -> i end )
@@ -64,7 +64,6 @@ class_idx_to_train_idxs =
6464class_idx_to_test_idxs =
6565 bin
6666 |> Nx .from_binary (type)
67- |> Nx .reshape (shape)
6867 |> Nx .to_flat_list ()
6968 |> Enum .with_index ()
7069 |> Enum .group_by (& elem (&1 , 0 ), fn {_ , i} -> i end )
@@ -80,22 +79,15 @@ With the index in place, the training loop draws one anchor and one sibling set
8079``` elixir
8180defmodule GetImages do
8281 def batch (train_images, class_idx_to_train_idxs) do
83- anchors_idx = Enum .map (0 .. 9 , fn class ->
84- indices = class_idx_to_train_idxs[class]
85- Enum .random (indices)
86- end )
87-
88- positives_idx = Enum .map (0 .. 9 , fn class ->
89- indices = class_idx_to_train_idxs[class]
90- # Exclude the anchor from possible positives
91- anchor_idx = Enum .at (anchors_idx, class)
92- indices
93- |> Enum .filter (fn idx -> idx != anchor_idx end )
94- |> Enum .random ()
95- end )
82+ {anchors_idx, positives_idx} =
83+ Enum .unzip (for class <- 0 .. 9 do
84+ [a, p] = Enum .take_random (class_idx_to_train_idxs[class], 2 )
85+ {a, p}
86+ end )
9687
9788 anchors = Nx .take (train_images, Nx .tensor (anchors_idx)) |> Nx .reshape ({10 , 32 , 32 , 3 })
9889 positives = Nx .take (train_images, Nx .tensor (positives_idx)) |> Nx .reshape ({10 , 32 , 32 , 3 })
90+
9991 {anchors, positives}
10092 end
10193end
@@ -155,13 +147,9 @@ defmodule MetricModel do
155147 end
156148
157149 defn normalize (x) do
158- den =
159- Nx .multiply (x, x)
160- |> Nx .sum (axes: [- 1 ], keep_axes: true )
161- |> Nx .sqrt ()
162- den = Nx .max (den, 1.0e-7 )
163- Nx .divide (x, den)
164- end
150+ norm = Nx .LinAlg .norm (x, axes: [- 1 ], keep_axes: true )
151+ Nx .divide (x, norm)
152+ end
165153
166154end
167155```
@@ -211,7 +199,7 @@ The training loop then uses that loss to nudge parameters, pulling same-class ve
211199defmodule MetricLearning do
212200 import Nx .Defn
213201 require Logger
214-
202+
215203 defn objective_fn (predict_fn, params, {anchor, positive}) do
216204 %{prediction: anchor_embeddings} = predict_fn .(params, %{" input" => anchor})
217205 %{prediction: positive_embeddings} = predict_fn .(params, %{" input" => positive})
@@ -304,43 +292,40 @@ near_neighbors_per_example = 10
304292
305293embeddings = Nx .rename (embeddings, [nil , nil ])
306294gram_matrix = Nx .dot (embeddings, Nx .transpose (embeddings))
295+
307296{_vals , neighbors} = Nx .top_k (gram_matrix, k: near_neighbors_per_example + 1 )
297+
308298:ok
309299```
310300
311- To visually inspect how well our embeddings capture similarity, we create a collage for each of the ten classes. For each class, we randomly pick one example and place it in the first column. Then, in the next ten columns, we display its ten closest neighbors so you can see which images the network considers its nearest matches.
301+ To visually inspect how well our embeddings capture similarity, we create a collage for each of the ten classes. For each class, we pick its first test example and place it in the first column. Then, in the next ten columns, we display its ten closest neighbors to see which images the network considers its nearest matches.
312302
313303``` elixir
314304# take first image of each class
315305example_per_class_idx =
316306 0 .. 9
317307 |> Enum .map (fn class_idx ->
318- class_idx_to_test_idxs[class_idx] |> Enum .random ( )
308+ class_idx_to_test_idxs[class_idx] |> Enum .at ( 0 )
319309 end )
320310 |> Nx .tensor (type: {:s , 64 })
321311
322312# take nearest neighbors for each example
323313neighbors_for_samples = Nx .take (neighbors, example_per_class_idx, axis: 0 )
324314
325- # show the ten closest images
326- images = for row_idx <- 0 .. 9 do
327- neighbour_idxs =
328- neighbors_for_samples
329- |> Nx .slice ([row_idx, 0 ], [1 , near_neighbors_per_example])
330- |> Nx .squeeze ()
315+ neighbour_idxs =
316+ neighbors_for_samples
331317 |> Nx .to_flat_list ()
332318
333- images =
334- for idx <- neighbour_idxs do
335- test_images
336- |> Nx .take (Nx .tensor ([idx]), axis: 0 )
337- |> Nx .squeeze ()
338- |> Nx .transpose (axes: [:width , :height , :channels ])
339- |> create_kino_image .()
340- end
319+ images =
320+ for idx <- neighbour_idxs do
321+ test_images[idx]
322+ |> Nx .squeeze ()
323+ |> Nx .transpose (axes: [:width , :height , :channels ])
324+ |> create_kino_image .()
325+ end
326+
327+ Kino .render (Kino .Layout .grid (images, columns: 11 ))
341328
342- Kino .render (Kino .Layout .grid (images, columns: near_neighbors_per_example))
343- end
344329:ok
345330```
346331
0 commit comments