@@ -418,10 +418,10 @@ def _(func, types, args, kwargs):
418418
419419@implements (aten .slice .Tensor )
420420def _ (func , types , args , kwargs ):
421- """Only supports slicing for dim == 1 and dim == 2
422- original tensor shape has dimension (N, K)
423- qdata has dimension (N, K)
424- scale (per row quantization) has dimension: (N,)
421+ """Supports slicing for 1d, 2d, and 3d tensors
422+ original tensor shape has dimension (N, K), or (B, N, K)
423+ qdata has dimension (N, K) or (B, N, K)
424+ scale (per row quantization) has dimension: (N,) or (B, N)
425425
426426 since qdata has the same dimension as original tensor, we can directly slice that
427427 for scale, we'll do a slice when dim is 0, and don't need to do anything for dim 1
@@ -431,12 +431,12 @@ def _(func, types, args, kwargs):
431431 """
432432 self , dim , start , end , step = fill_defaults (args , 5 , [0 , None , None , 1 ])
433433 assert step == 1
434- assert dim == 0 or dim == 1 , f"Only dim==0 or 1 are supported, got: { dim } "
434+ assert dim == 0 or dim == 1 or dim == 2 , f"Only dim==0,1,2 are supported, got: dim= { dim } "
435435 if end >= self .shape [dim ]:
436436 end = self .shape [dim ]
437437
438- assert self .qdata .ndim == 2 , (
439- f"Expected packed weight to have dim 2, got { self .qdata .dim } "
438+ assert self .qdata .ndim == 2 or self . qdata . ndim == 3 , (
439+ f"Expected packed weight to have dim==2,3 got: dim= { self .qdata .ndim } "
440440 )
441441
442442 # Always slice the qdata
@@ -638,6 +638,34 @@ def _(func, types, args, kwargs):
638638 )
639639 return return_and_correct_aliasing (func , args , kwargs , new_float8_tensor )
640640
@implements(aten.unsqueeze.default)
def _(func, types, args, kwargs):
    """Handle ``aten.unsqueeze.default`` for Float8Tensor.

    Only a leading-dimension unsqueeze (``dim == 0``) is supported. Both
    ``qdata`` and ``scale`` gain the new size-1 dimension, and the block
    size is recomputed per axis as ``qdata.shape[i] // scale.shape[i]``.

    NOTE(review): the recomputation indexes ``scale.shape`` with qdata's
    axis index, so it assumes scale has the same rank as qdata after the
    unsqueeze — an IndexError is raised otherwise.
    """
    self, dim = args
    assert dim == 0, f"Only dim == 0 is supported, got: {dim}"
    unsqueezed_qdata = self.qdata.unsqueeze(dim=dim)
    unsqueezed_scale = self.scale.unsqueeze(dim=dim)
    # Block size per axis is the ratio of qdata extent to scale extent.
    new_block_size = [
        unsqueezed_qdata.shape[i] // unsqueezed_scale.shape[i]
        for i in range(len(unsqueezed_qdata.shape))
    ]
    unsqueezed = self.__class__(
        unsqueezed_qdata,
        unsqueezed_scale,
        new_block_size,
        self.mm_config,
        self.act_quant_kwargs,
        self.kernel_preference,
        self.dtype,
    )
    return return_and_correct_aliasing(func, args, kwargs, unsqueezed)
661+
662+
@implements(aten.add.Tensor)
def _(func, types, args, kwargs):
    """Handle ``aten.add.Tensor`` for a plain tensor plus a Float8Tensor.

    Requires exactly two operands, with ``args[0]`` a regular
    ``torch.Tensor`` and ``args[1]`` a ``Float8Tensor``; the quantized
    operand is dequantized first, so the result is a regular
    (non-quantized) high-precision tensor.
    """
    assert len(args) == 2, f"Expected 2 args, got {len(args)}"
    plain_tensor, float8_tensor = args
    assert isinstance(plain_tensor, torch.Tensor) and isinstance(
        float8_tensor, Float8Tensor
    ), (
        f"Expected args[0]==torch.Tensor and args[1]==Float8Tensor, "
        f"got {type(args[0]), type(args[1])}"
    )
    sum_tensor = plain_tensor + float8_tensor.dequantize()
    return return_and_correct_aliasing(func, args, kwargs, sum_tensor)
641669
# Present Float8Tensor under the public torchao.quantization namespace
# (affects its repr and module lookup) rather than the private module
# that actually defines it.
Float8Tensor.__module__ = "torchao.quantization"
643671
0 commit comments