@@ -212,12 +212,10 @@ def test_inner_composite(mode):
212212 y16 = op (n_steps , x16 )
213213 assert y16 .type .dtype == "float16"
214214
215- fn32 = function ([n_steps , x16 ], y16 , mode = mode )
216- np .testing .assert_allclose (
217- fn32 (n_steps = 9 , x16 = np .array (4.73 , dtype = "float16" )),
218- 4.73 + 9 ,
219- rtol = 1e-3 ,
220- )
215+ fn16 = function ([n_steps , x16 ], y16 , mode = mode )
216+ out16 = fn16 (n_steps = 9 , x16 = np .array (4.73 , dtype = "float16" ))
217+ assert out16 .dtype == "float16"
218+ assert np .isnan (out16 )
221219
222220
223221@mode
@@ -243,8 +241,10 @@ def test_inner_loop(mode):
243241 y16 = outer_loop_op (n_steps , x16 , n_steps )
244242 assert y16 .type .dtype == "float16"
245243
246- fn32 = function ([n_steps , x16 ], y16 , mode = mode )
244+ fn16 = function ([n_steps , x16 ], y16 , mode = mode )
245+ out16 = fn16 (n_steps = 3 , x16 = np .array (2.5 , dtype = "float16" ))
246+ assert out16 .dtype == "float16"
247247 np .testing .assert_allclose (
248- fn32 ( n_steps = 3 , x16 = np . array ( 2.5 , dtype = "float16" )) ,
248+ out16 ,
249249 3 ** 2 + 2.5 ,
250250 )
0 commit comments