@@ -416,6 +416,50 @@ def test_lrn():
416416 verify_lrn ((5 , 5 , 5 , 5 ), 3 , 'float32' )
417417 verify_lrn ((5 , 5 , 5 , 5 ), 3 , 'float32' , alpha = 0.0002 , beta = 0.5 , bias = 2.0 )
418418
419+
420+ def verify_instance_norm (shape , axis = 1 ):
421+
422+ def _get_python_instance_norm (x , gamma , beta , epsilon = 1e-5 ):
423+ dims_x = len (x .shape )
424+ axis = tuple (range (2 , dims_x ))
425+ mean = np .mean (x , axis = axis , keepdims = True )
426+ var = np .var (x , axis = axis , keepdims = True )
427+ dim_ones = (1 ,) * (dims_x - 2 )
428+ gamma = gamma .reshape (- 1 , * dim_ones )
429+ beta = beta .reshape (- 1 , * dim_ones )
430+ return gamma * (x - mean ) / np .sqrt (var + epsilon ) + beta
431+
432+ x = np .random .randn (* shape ).astype (np .float32 )
433+ gamma = np .random .randn (shape [1 ]).astype (np .float32 )
434+ beta = np .random .randn (shape [1 ]).astype (np .float32 )
435+ epsilon = 1e-5
436+ y = _get_python_instance_norm (x , gamma , beta , epsilon ).astype (np .float32 )
437+
438+ node = onnx .helper .make_node (
439+ 'InstanceNormalization' ,
440+ inputs = ['x' , 'gamma' , 'beta' ],
441+ outputs = ['y' ],
442+ epsilon = epsilon ,
443+ )
444+ graph = helper .make_graph ([node ],
445+ "instance_norm_test" ,
446+ inputs = [helper .make_tensor_value_info ("x" , TensorProto .FLOAT , list (shape )),
447+ helper .make_tensor_value_info ("gamma" , TensorProto .FLOAT , (shape [1 ],)),
448+ helper .make_tensor_value_info ("beta" , TensorProto .FLOAT , (shape [1 ],))],
449+ outputs = [helper .make_tensor_value_info ("y" , TensorProto .FLOAT , list (shape ))])
450+ model = helper .make_model (graph , producer_name = 'instance_norm_test' )
451+ for target , ctx in ctx_list ():
452+ tvm_out = get_tvm_output (model , [x , gamma , beta ], target , ctx , shape , 'float32' )
453+ tvm .testing .assert_allclose (y , tvm_out , rtol = 1e-5 , atol = 1e-5 )
454+
455+
456+ def test_instance_norm ():
457+ verify_instance_norm ((2 , 3 , 4 , 5 ))
458+ verify_instance_norm ((32 , 64 , 80 , 64 ))
459+ verify_instance_norm ((8 , 6 , 5 ))
460+ verify_instance_norm ((8 , 7 , 6 , 5 , 4 ))
461+
462+
419463def _test_upsample_nearest ():
420464 scale = 2
421465 in_shape = (1 , 1 , 3 , 3 )
@@ -1270,6 +1314,7 @@ def test_erf():
12701314 test_matmul ()
12711315 test_gather ()
12721316 test_lrn ()
1317+ test_instance_norm ()
12731318 test_upsample ()
12741319 test_forward_min ()
12751320 test_forward_max ()
0 commit comments