diff --git a/src/libtorch_utils.cc b/src/libtorch_utils.cc index bd7353b..f9cf227 100644 --- a/src/libtorch_utils.cc +++ b/src/libtorch_utils.cc @@ -50,6 +50,8 @@ ConvertTorchTypeToDataType(const torch::ScalarType& stype) return TRITONSERVER_TYPE_FP32; case torch::kDouble: return TRITONSERVER_TYPE_FP64; + case torch::kBFloat16: + return TRITONSERVER_TYPE_BF16; default: break; } @@ -89,6 +91,9 @@ ConvertDataTypeToTorchType(const TRITONSERVER_DataType dtype) case TRITONSERVER_TYPE_FP64: type = torch::kDouble; break; + case TRITONSERVER_TYPE_BF16: + type = torch::kBFloat16; + break; case TRITONSERVER_TYPE_UINT16: case TRITONSERVER_TYPE_UINT32: case TRITONSERVER_TYPE_UINT64: @@ -130,6 +135,8 @@ ModelConfigDataTypeToTorchType(const std::string& data_type_str) type = torch::kFloat; } else if (dtype == "FP64") { type = torch::kDouble; + } else if (dtype == "BF16") { + type = torch::kBFloat16; } else { return std::make_pair(false, type); }