File tree (expand / collapse): 1 file changed, +1 -7 lines changed
cpp/tensorrt_llm/plugins/fusedLayernormPlugin (expand / collapse file tree): 1 file changed, +1 -7 lines changed
Columns: original file line number | diff line number | diff line change
@@ -82,9 +82,6 @@ nvinfer1::DimsExprs FusedLayernormPlugin::getOutputDimensions(
8282 {
8383 ret.d [di] = inputs[0 ].d [di];
8484 }
85- // Div up by 16 as the storage type has 16 FP4 values per element.
86- ret.d [ret.nbDims - 1 ]
87- = exprBuilder.operation (DimensionOperation::kCEIL_DIV , *ret.d [ret.nbDims - 1 ], *exprBuilder.constant (16 ));
8885 return ret;
8986 }
9087
@@ -105,11 +102,8 @@ nvinfer1::DimsExprs FusedLayernormPlugin::getOutputDimensions(
105102 = exprBuilder.operation (DimensionOperation::kCEIL_DIV , *ret.d [ret.nbDims - 2 ], *exprBuilder.constant (128 ));
106103 ret.d [ret.nbDims - 2 ] = exprBuilder.operation (DimensionOperation::kPROD , *dimM, *exprBuilder.constant (128 ));
107104 // Hidden size dimension.
108- // Div (rounding up) by 64.
109- // 16 elements share one SF, and we need number of SFs to be multiple of 4.
110- // The output data type is already int32_t (4 SFs), so just need to div (rounding up) by 64.
111105 ret.d [ret.nbDims - 1 ]
112- = exprBuilder.operation (DimensionOperation::kCEIL_DIV , *ret.d [ret.nbDims - 1 ], *exprBuilder.constant (64 ));
106+ = exprBuilder.operation (DimensionOperation::kCEIL_DIV , *ret.d [ret.nbDims - 1 ], *exprBuilder.constant (16 ));
113107 return ret;
114108 }
115109 catch (std::exception const & e)
You can’t perform that action at this time.
0 commit comments