|
5 | 5 | m = Dense(10, 5)
|
6 | 6 | @test_throws DimensionMismatch outputsize(m, (5, 2)) == (5, 1)
|
7 | 7 | @test outputsize(m, (10,); padbatch=true) == (5, 1)
|
| 8 | + @test outputsize(m, (10,)) == (5,) |
| 9 | + @test outputsize(m, (10, 6, 7)) == (5, 6, 7) |
8 | 10 |
|
9 | 11 | m = Chain(Dense(10, 8, σ), Dense(8, 5), Dense(5, 2))
|
10 | 12 | @test outputsize(m, (10,); padbatch=true) == (2, 1)
|
|
41 | 43 | @test outputsize(m, (10, 10, 3, 1)) == (10, 10, 19, 1)
|
42 | 44 | end
|
43 | 45 |
|
| 46 | +@testset "embeddings" begin |
| 47 | + # Here outputsize expects indices, not one-hot representation: |
| 48 | + m = Embedding(3 => 4) |
| 49 | + @test outputsize(m, (3, 7)) == (4, 3, 7) == size(m(rand(1:3, 3, 7))) |
| 50 | + @test outputsize(m, (5, 6, 7)) == (4, 5, 6, 7) == size(m(rand(1:3, 5, 6, 7))) |
| 51 | + |
| 52 | + m = Chain(x -> Flux.onehotbatch(x, 1:5), Embedding(5 => 7)) |
| 53 | + @test size(m([3,4])) == (7, 2) |
| 54 | + @test outputsize(m, (2,)) == (7, 2) |
| 55 | + # This works because Flux.onehotbatch([nil, nil], 1:5) makes a 5×2 OneHotMatrix |
| 56 | + # But e.g. Flux.onehotbatch([nil, nil], 'a':'e') will not work. |
| 57 | +end |
| 58 | + |
44 | 59 | @testset "multiple inputs" begin
|
45 | 60 | m = Parallel(vcat, Dense(2, 4, relu), Dense(3, 6, relu))
|
46 | 61 | @test outputsize(m, (2,), (3,)) == (10,)
|
|
0 commit comments