@@ -21,11 +21,11 @@ def benchmark(function, args, num_runs):
 
 def test_vs_existing():
     def new_(scale):
-        fake_tensor = torch.randint(2 ** 8 - 1, (1, scale, scale), dtype=torch.uint8).cuda()
+        fake_tensor = torch.randint(2 ** 8, (1, scale, scale), dtype=torch.uint8).cuda()
         packed = pack(fake_tensor, 4, dim=1)
         unpacked = unpack(packed, 4, dim=1)
     def old_(scale):
-        fake_tensor = torch.randint(2 ** 8 - 1, (1, scale, scale), dtype=torch.uint8).cuda()
+        fake_tensor = torch.randint(2 ** 8, (1, scale, scale), dtype=torch.uint8).cuda()
         packed = pack_uint4(fake_tensor)
         unpacked = unpack_uint4(packed)
 
@@ -55,9 +55,9 @@ class W4A16_symmetric_weight_only(torch.nn.Module):
     def __init__(self, scale):
         super().__init__()
         assert scale % 4 == 0
-        self.l1 = torch.randint(2 ** 8 - 1, (scale, scale), dtype=torch.uint8).cuda()
+        self.l1 = torch.randint(2 ** 8, (scale, scale), dtype=torch.uint8).cuda()
         self.s1 = torch.tensor((scale), dtype=torch.float16).cuda()
-        self.l2 = torch.randint(2 ** 8 - 1, (scale // 2, scale // 4), dtype=torch.uint8).cuda()
+        self.l2 = torch.randint(2 ** 8, (scale // 2, scale // 4), dtype=torch.uint8).cuda()
         self.s2 = torch.tensor((scale // 4), dtype=torch.float16).cuda()
 
 
@@ -79,7 +79,7 @@ def forward(self, x):
     b = torch.compile(b, fullgraph=True)
 
     test_input = torch.randn(scale * 2, dtype=torch.float16).cuda()
-    forward_args = [test_input ]
+    forward_args = [test_input]
     b.forward(test_input)
     print("scale: ", scale)
     print("fp16 time: ", benchmark(a.forward, forward_args, 100))
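
Note on the bound change: torch.randint treats its upper bound as exclusive (values are drawn from [0, high)), so the previous 2 ** 8 - 1 bound could never produce the value 255 in these uint8 test tensors, while 2 ** 8 covers the full 0-255 range. A minimal CPU-only sketch, separate from the benchmark code above, that illustrates the difference:

import torch

# torch.randint(high, size) samples from [0, high); high is exclusive.
old = torch.randint(2 ** 8 - 1, (10000,), dtype=torch.uint8)  # maximum observable value is 254
new = torch.randint(2 ** 8, (10000,), dtype=torch.uint8)      # can reach the full uint8 maximum, 255
print(old.max().item(), new.max().item())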