@@ -8,68 +8,86 @@ namespace core {
 namespace lowering {
 namespace passes {
 
-void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph) {
+void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
   std::string masked_fill_pattern = R"IR(
     graph(%self, %mask, %value):
       %out: Tensor = aten::masked_fill_(%self, %mask, %value)
       return (%out))IR";
 
   // Calls to masked_fill_ often utilize CPU tensors, and as such
-  // should be casted to CUDA to avoid device mismatch errors
-  std::string unpacked_pattern = R"IR(
+  // should be moved to gpu to avoid device mismatch errors
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
     graph(%self, %mask, %value):
-      %device: Device = prim::Constant[value="cuda"]()
+      %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
       %dtype: NoneType = prim::Constant()
       %false: bool = prim::Constant[value=0]()
       %mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false)
       %self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false)
-      %out: Tensor = aten::masked_fill_(%self_cuda, %mask_cuda, %value)
+      %out: Tensor = aten::masked_fill(%self_cuda, %mask_cuda, %value)
       return (%out))IR";
 
+  auto unpacked_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
   torch::jit::SubgraphRewriter masked_fill_rewriter;
   masked_fill_rewriter.RegisterRewritePattern(masked_fill_pattern, unpacked_pattern);
   masked_fill_rewriter.runOnGraph(graph);
   LOG_GRAPH("After unpack and cast masked_fill_: " << *graph);
 }
 
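The hunk above replaces the hard-coded "cuda" constant with a caller-supplied device: the replacement IR is split into two raw-string halves and the device name is concatenated between them, so the prim::Constant Device node carries whatever target_device_name the caller passes in. A minimal standalone sketch of that assembly follows; it is not part of the commit, and "cuda:0" is only an assumed example value for target_device_name.

// Standalone sketch (not from the commit): prints the assembled replacement
// pattern. "cuda:0" is an assumed example value for target_device_name.
#include <iostream>
#include <string>

int main() {
  std::string clean_pattern_part_1 = R"IR(
    graph(%self, %mask, %value):
      %device: Device = prim::Constant[value=")IR";

  std::string clean_pattern_part_2 = R"IR("]()
      %dtype: NoneType = prim::Constant()
      %false: bool = prim::Constant[value=0]()
      %mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false)
      %self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false)
      %out: Tensor = aten::masked_fill(%self_cuda, %mask_cuda, %value)
      return (%out))IR";

  std::string target_device_name = "cuda:0";  // assumed example

  // The Device constant in the printed IR now reads value="cuda:0"
  std::cout << clean_pattern_part_1 + target_device_name + clean_pattern_part_2 << "\n";
  return 0;
}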
-void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph) {
+void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
   std::string num_to_tensor_cast_pattern = R"IR(
     graph(%1: Scalar):
       %2: Tensor = prim::NumToTensor(%1)
       return (%2))IR";
 
-  // 0D Tensors are initialized on cpu, and need to be casted to CUDA
+  // 0D Tensors are initialized on cpu, and need to be moved to gpu
   // to avoid device mismatch issues
-  std::string num_to_tensor_clean_pattern = R"IR(
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
     graph(%1: Scalar):
       %2: Tensor = prim::NumToTensor(%1)
-      %device: Device = prim::Constant[value="cuda"]()
+      %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
       %dtype: NoneType = prim::Constant()
       %false: bool = prim::Constant[value=0]()
       %3: Tensor = aten::to(%2, %device, %dtype, %false, %false)
       return (%3))IR";
 
+  auto num_to_tensor_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
   torch::jit::SubgraphRewriter num_to_tensor_cast_rewriter;
   num_to_tensor_cast_rewriter.RegisterRewritePattern(num_to_tensor_cast_pattern, num_to_tensor_clean_pattern);
   num_to_tensor_cast_rewriter.runOnGraph(graph);
 
   LOG_GRAPH("After unpack and cast NumToTensor: " << *graph);
 }
 
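The same splicing scheme is applied to the NumToTensor pass above. As a quick way to see the rewrite in action, the sketch below parses a toy graph and runs the pass on it; this is an assumed usage example rather than code from the commit, the include path and enclosing namespace are assumptions, and "cuda:0" again stands in for the caller's device string.

// Assumed usage sketch (not from the commit): exercise the pass on a toy graph.
// Assumes it is compiled inside the same passes namespace as the functions in
// the diff, so the call below resolves unqualified.
#include <memory>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>

#include "core/lowering/passes/passes.h"  // assumed header declaring the pass

void demoNumToTensorRewrite() {
  auto graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(R"IR(
    graph(%1: Scalar):
      %2: Tensor = prim::NumToTensor(%1)
      return (%2))IR",
      graph.get());

  // "cuda:0" is an assumed example device string.
  UnpackAndCastNumToTensor(graph, "cuda:0");

  // After the rewrite, the NumToTensor output should feed an aten::to that
  // moves the 0D tensor onto cuda:0.
  graph->dump();
}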
-void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph) {
+void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
   std::string full_cast_pattern = R"IR(
     graph(%1, %2, %3, %4, %5, %6):
       %out: Tensor = aten::full(%1, %2, %3, %4, %5, %6)
       return (%out))IR";
 
-  // Tensors created via aten::full are initialized on cpu, and need to be casted to CUDA
+  // Tensors created via aten::full are initialized on cpu, and need to be casted to gpu
   // to avoid device mismatch issues
-  std::string full_clean_pattern = R"IR(
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
     graph(%1, %2, %3, %4, %5, %6):
-      %cuda: Device = prim::Constant[value="cuda"]()
-      %out: Tensor = aten::full(%1, %2, %3, %4, %cuda, %6)
+      %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
+      %out: Tensor = aten::full(%1, %2, %3, %4, %device, %6)
       return (%out))IR";
 
+  auto full_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
   torch::jit::SubgraphRewriter full_cast_rewriter;
   full_cast_rewriter.RegisterRewritePattern(full_cast_pattern, full_clean_pattern);
   full_cast_rewriter.runOnGraph(graph);
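With all three passes now requiring an explicit device string, their call sites (not shown in this hunk) must build that string from the configured target device. The wrapper below is purely hypothetical and only illustrates the expected shape of the argument, i.e. "cuda:" followed by a GPU id.

// Hypothetical caller (not from the commit): derives the device string from a
// GPU id and runs the three device-casting passes in sequence. Assumes the
// same passes namespace and header as above.
#include <memory>
#include <string>

#include <torch/csrc/jit/ir/ir.h>

#include "core/lowering/passes/passes.h"  // assumed header declaring the passes

void runDeviceCastingPasses(std::shared_ptr<torch::jit::Graph>& graph, int gpu_id) {
  // e.g. gpu_id == 0 yields "cuda:0"
  std::string target_device_name = "cuda:" + std::to_string(gpu_id);

  UnpackAndCastMaskedFill(graph, target_device_name);
  UnpackAndCastNumToTensor(graph, target_device_name);
  UnpackAndCastFull(graph, target_device_name);
}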