#include <limits>

#include "core/conversion/converters/converters.h"
#include "core/util/prelude.h"
#include "torch/torch.h"

 | 5 | +namespace torch_tensorrt {  | 
 | 6 | +namespace core {  | 
 | 7 | +namespace conversion {  | 
 | 8 | +namespace converters {  | 
 | 9 | +namespace impl {  | 
 | 10 | +namespace {  | 
 | 11 | + | 
 | 12 | +auto linear_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(  | 
 | 13 | +    {"trt::attn_bias_from_attn_mask(Tensor attn_mask) -> Tensor",  | 
 | 14 | +     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {  | 
 | 15 | +       // Converter for internal op used in unpack_scaled_dot_product_attention  | 
 | 16 | +       // We don't have visibility to check types during lowering and can't introduce conditionals so do type specific  | 
 | 17 | +       // specialization here  | 
 | 18 | +       auto in = args[0].ITensorOrFreeze(ctx);  | 
 | 19 | +       auto out = in;  | 
 | 20 | +       if (in->getType() == nvinfer1::DataType::kBOOL) {  | 
 | 21 | +         auto not_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kNOT);  | 
 | 22 | +         TORCHTRT_CHECK(not_layer, "Unable to create not layer for attn_bias_from_attn_mask");  | 
 | 23 | +         not_layer->setName((util::node_info(n) + "_not").c_str());  | 
 | 24 | +         auto neg_inf = torch::tensor(-std::numeric_limits<float>::infinity());  | 
 | 25 | +         auto neg_inf_itensor = tensor_to_const(ctx, neg_inf);  | 
 | 26 | +         auto prod_layer = add_elementwise(  | 
 | 27 | +             ctx,  | 
 | 28 | +             nvinfer1::ElementWiseOperation::kPROD,  | 
 | 29 | +             not_layer->getOutput(0),  | 
 | 30 | +             neg_inf_itensor,  | 
 | 31 | +             util::node_info(n) + "_mul");  | 
 | 32 | +         auto add_layer = add_elementwise(  | 
 | 33 | +             ctx, nvinfer1::ElementWiseOperation::kSUM, prod_layer->getOutput(0), in, util::node_info(n) + "_add");  | 
 | 34 | +         out = add_layer->getOutput(0);  | 
 | 35 | +       }  | 
 | 36 | +       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out);  | 
 | 37 | +       LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());  | 
 | 38 | +       LOG_DEBUG("Output tensor type: " << out_tensor->getType());  | 
 | 39 | +       return true;  | 
 | 40 | +     }});  | 
 | 41 | +} // namespace  | 
 | 42 | +} // namespace impl  | 
 | 43 | +} // namespace converters  | 
 | 44 | +} // namespace conversion  | 
 | 45 | +} // namespace core  | 
 | 46 | +} // namespace torch_tensorrt  | 
0 commit comments